chore(deps): bump Go 1.21→1.24 and resync vendor for Pion WebRTC v4 compat

Pion webrtc/v4 (v4.2.11) requires Go 1.24+. Upstream datarhei was at
go 1.21.0. Bumping to go 1.24.0 pulls minor bumps across testify,
golang.org/x/{crypto,net,sync,sys,text,time,tools,mod}; vendor/ is
regenerated via 'go mod vendor' to reflect the new versions.

No application code changes; pure dep bump to unblock M1.
This commit is contained in:
Zac Gaetano 2026-04-17 08:43:31 -04:00
parent 262a393b8d
commit 651a9a3eb5
385 changed files with 55239 additions and 79639 deletions

22
go.mod
View file

@ -1,8 +1,6 @@
module github.com/datarhei/core/v16
go 1.21.0
toolchain go1.22.1
go 1.24.0
require (
github.com/99designs/gqlgen v0.17.47
@ -26,13 +24,13 @@ require (
github.com/prometheus/client_golang v1.19.1
github.com/puzpuzpuz/xsync/v3 v3.1.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/stretchr/testify v1.9.0
github.com/stretchr/testify v1.11.1
github.com/swaggo/echo-swagger v1.4.1
github.com/swaggo/swag v1.16.3
github.com/vektah/gqlparser/v2 v2.5.12
github.com/xeipuuv/gojsonschema v1.2.0
go.uber.org/zap v1.27.0
golang.org/x/mod v0.17.0
golang.org/x/mod v0.32.0
)
require (
@ -94,13 +92,13 @@ require (
github.com/yusufpapurcu/wmi v1.2.4 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/time v0.10.0 // indirect
golang.org/x/tools v0.41.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect

35
go.sum
View file

@ -177,8 +177,9 @@ github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/swaggo/echo-swagger v1.4.1 h1:Yf0uPaJWp1uRtDloZALyLnvdBeoEL5Kc7DtnjzO/TUk=
github.com/swaggo/echo-swagger v1.4.1/go.mod h1:C8bSi+9yH2FLZsnhqMZLIZddpUxZdBYuNHbtaS1Hljc=
github.com/swaggo/files/v2 v2.0.0 h1:hmAt8Dkynw7Ssz46F6pn8ok6YmGZqHSVLZ+HQM7i0kw=
@ -222,14 +223,14 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -239,14 +240,14 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=

View file

@ -7,10 +7,13 @@ import (
"time"
)
type CompareType int
// Deprecated: CompareType has only ever been for internal use and has accidentally been published since v1.6.0. Do not use it.
type CompareType = compareResult
type compareResult int
const (
compareLess CompareType = iota - 1
compareLess compareResult = iota - 1
compareEqual
compareGreater
)
@ -39,7 +42,7 @@ var (
bytesType = reflect.TypeOf([]byte{})
)
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
func compare(obj1, obj2 interface{}, kind reflect.Kind) (compareResult, bool) {
obj1Value := reflect.ValueOf(obj1)
obj2Value := reflect.ValueOf(obj2)
@ -325,7 +328,13 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
}
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
if timeObj1.Before(timeObj2) {
return compareLess, true
}
if timeObj1.Equal(timeObj2) {
return compareEqual, true
}
return compareGreater, true
}
case reflect.Slice:
{
@ -345,7 +354,7 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
}
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
return compareResult(bytes.Compare(bytesObj1, bytesObj2)), true
}
case reflect.Uintptr:
{
@ -381,7 +390,8 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
failMessage := fmt.Sprintf("\"%v\" is not greater than \"%v\"", e1, e2)
return compareTwoValues(t, e1, e2, []compareResult{compareGreater}, failMessage, msgAndArgs...)
}
// GreaterOrEqual asserts that the first element is greater than or equal to the second
@ -394,7 +404,8 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
failMessage := fmt.Sprintf("\"%v\" is not greater than or equal to \"%v\"", e1, e2)
return compareTwoValues(t, e1, e2, []compareResult{compareGreater, compareEqual}, failMessage, msgAndArgs...)
}
// Less asserts that the first element is less than the second
@ -406,7 +417,8 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{})
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
failMessage := fmt.Sprintf("\"%v\" is not less than \"%v\"", e1, e2)
return compareTwoValues(t, e1, e2, []compareResult{compareLess}, failMessage, msgAndArgs...)
}
// LessOrEqual asserts that the first element is less than or equal to the second
@ -419,7 +431,8 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
failMessage := fmt.Sprintf("\"%v\" is not less than or equal to \"%v\"", e1, e2)
return compareTwoValues(t, e1, e2, []compareResult{compareLess, compareEqual}, failMessage, msgAndArgs...)
}
// Positive asserts that the specified element is positive
@ -431,7 +444,8 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
failMessage := fmt.Sprintf("\"%v\" is not positive", e)
return compareTwoValues(t, e, zero.Interface(), []compareResult{compareGreater}, failMessage, msgAndArgs...)
}
// Negative asserts that the specified element is negative
@ -443,10 +457,11 @@ func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
failMessage := fmt.Sprintf("\"%v\" is not negative", e)
return compareTwoValues(t, e, zero.Interface(), []compareResult{compareLess}, failMessage, msgAndArgs...)
}
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
@ -459,17 +474,17 @@ func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedCompare
compareResult, isComparable := compare(e1, e2, e1Kind)
if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
return Fail(t, fmt.Sprintf(`Can not compare type "%T"`, e1), msgAndArgs...)
}
if !containsValue(allowedComparesResults, compareResult) {
return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
return Fail(t, failMessage, msgAndArgs...)
}
return true
}
func containsValue(values []CompareType, value CompareType) bool {
func containsValue(values []compareResult, value compareResult) bool {
for _, v := range values {
if v == value {
return true

View file

@ -50,10 +50,19 @@ func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string
return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...)
}
// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
// Emptyf asserts that the given value is "empty".
//
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
//
// assert.Emptyf(t, obj, "error message %s", "formatted")
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -104,8 +113,8 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{},
return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal.
// EqualValuesf asserts that two objects are equal or convertible to the larger
// type and equal.
//
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
@ -117,10 +126,8 @@ func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg stri
// Errorf asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if assert.Errorf(t, err, "error message %s", "formatted") {
// assert.Equal(t, expectedErrorf, err)
// }
// actualObj, err := SomeFunction()
// assert.Errorf(t, err, "error message %s", "formatted")
func Errorf(t TestingT, err error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -186,7 +193,7 @@ func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick
// assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") {
// // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false")
// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func EventuallyWithTf(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -438,7 +445,19 @@ func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interf
return IsNonIncreasing(t, object, append([]interface{}{msg}, args...)...)
}
// IsNotTypef asserts that the specified objects are not of the same type.
//
// assert.IsNotTypef(t, &NotMyStruct{}, &MyStruct{}, "error message %s", "formatted")
func IsNotTypef(t TestingT, theType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return IsNotType(t, theType, object, append([]interface{}{msg}, args...)...)
}
// IsTypef asserts that the specified objects are of the same type.
//
// assert.IsTypef(t, &MyStruct{}, &MyStruct{}, "error message %s", "formatted")
func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -568,8 +587,24 @@ func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, a
return NotContains(t, s, contains, append([]interface{}{msg}, args...)...)
}
// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// assert.NotElementsMatchf(t, [1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false
//
// assert.NotElementsMatchf(t, [1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true
//
// assert.NotElementsMatchf(t, [1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true
func NotElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...)
}
// NotEmptyf asserts that the specified object is NOT [Empty].
//
// if assert.NotEmptyf(t, obj, "error message %s", "formatted") {
// assert.Equal(t, "two", obj[1])
@ -604,7 +639,16 @@ func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg s
return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// NotErrorIsf asserts that at none of the errors in err's chain matches target.
// NotErrorAsf asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func NotErrorAsf(t TestingT, err error, target interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotErrorAs(t, err, target, append([]interface{}{msg}, args...)...)
}
// NotErrorIsf asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
@ -667,12 +711,15 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string,
return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
// NotSubsetf asserts that the list (array, slice, or map) does NOT contain all
// elements given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
// assert.NotSubsetf(t, [1, 3, 4], {1: "one", 2: "two"}, "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, ["z"], "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -756,11 +803,15 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg
return Same(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
// Subsetf asserts that the list (array, slice, or map) contains all elements
// given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
// assert.Subsetf(t, [1, 2, 3], {1: "one", 2: "two"}, "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, ["x"], "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()

View file

@ -92,10 +92,19 @@ func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg st
return ElementsMatchf(a.t, listA, listB, msg, args...)
}
// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
// Empty asserts that the given value is "empty".
//
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
//
// a.Empty(obj)
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -103,10 +112,19 @@ func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool {
return Empty(a.t, object, msgAndArgs...)
}
// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
// Emptyf asserts that the given value is "empty".
//
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
//
// a.Emptyf(obj, "error message %s", "formatted")
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -186,8 +204,8 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
return EqualExportedValuesf(a.t, expected, actual, msg, args...)
}
// EqualValues asserts that two objects are equal or convertible to the same types
// and equal.
// EqualValues asserts that two objects are equal or convertible to the larger
// type and equal.
//
// a.EqualValues(uint32(123), int32(123))
func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
@ -197,8 +215,8 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
return EqualValues(a.t, expected, actual, msgAndArgs...)
}
// EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal.
// EqualValuesf asserts that two objects are equal or convertible to the larger
// type and equal.
//
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
@ -224,10 +242,8 @@ func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string
// Error asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if a.Error(err) {
// assert.Equal(t, expectedError, err)
// }
// actualObj, err := SomeFunction()
// a.Error(err)
func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -297,10 +313,8 @@ func (a *Assertions) ErrorIsf(err error, target error, msg string, args ...inter
// Errorf asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if a.Errorf(err, "error message %s", "formatted") {
// assert.Equal(t, expectedErrorf, err)
// }
// actualObj, err := SomeFunction()
// a.Errorf(err, "error message %s", "formatted")
func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -336,7 +350,7 @@ func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, ti
// a.EventuallyWithT(func(c *assert.CollectT) {
// // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false")
// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -361,7 +375,7 @@ func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor
// a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") {
// // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false")
// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func (a *Assertions) EventuallyWithTf(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -868,7 +882,29 @@ func (a *Assertions) IsNonIncreasingf(object interface{}, msg string, args ...in
return IsNonIncreasingf(a.t, object, msg, args...)
}
// IsNotType asserts that the specified objects are not of the same type.
//
// a.IsNotType(&NotMyStruct{}, &MyStruct{})
func (a *Assertions) IsNotType(theType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsNotType(a.t, theType, object, msgAndArgs...)
}
// IsNotTypef asserts that the specified objects are not of the same type.
//
// a.IsNotTypef(&NotMyStruct{}, &MyStruct{}, "error message %s", "formatted")
func (a *Assertions) IsNotTypef(theType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsNotTypef(a.t, theType, object, msg, args...)
}
// IsType asserts that the specified objects are of the same type.
//
// a.IsType(&MyStruct{}, &MyStruct{})
func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -877,6 +913,8 @@ func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAnd
}
// IsTypef asserts that the specified objects are of the same type.
//
// a.IsTypef(&MyStruct{}, &MyStruct{}, "error message %s", "formatted")
func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1128,8 +1166,41 @@ func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg strin
return NotContainsf(a.t, s, contains, msg, args...)
}
// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// a.NotElementsMatch([1, 1, 2, 3], [1, 1, 2, 3]) -> false
//
// a.NotElementsMatch([1, 1, 2, 3], [1, 2, 3]) -> true
//
// a.NotElementsMatch([1, 2, 3], [1, 2, 4]) -> true
func (a *Assertions) NotElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotElementsMatch(a.t, listA, listB, msgAndArgs...)
}
// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// a.NotElementsMatchf([1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false
//
// a.NotElementsMatchf([1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true
//
// a.NotElementsMatchf([1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true
func (a *Assertions) NotElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotElementsMatchf(a.t, listA, listB, msg, args...)
}
// NotEmpty asserts that the specified object is NOT [Empty].
//
// if a.NotEmpty(obj) {
// assert.Equal(t, "two", obj[1])
@ -1141,8 +1212,7 @@ func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) boo
return NotEmpty(a.t, object, msgAndArgs...)
}
// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
// NotEmptyf asserts that the specified object is NOT [Empty].
//
// if a.NotEmptyf(obj, "error message %s", "formatted") {
// assert.Equal(t, "two", obj[1])
@ -1200,7 +1270,25 @@ func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg str
return NotEqualf(a.t, expected, actual, msg, args...)
}
// NotErrorIs asserts that at none of the errors in err's chain matches target.
// NotErrorAs asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func (a *Assertions) NotErrorAs(err error, target interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotErrorAs(a.t, err, target, msgAndArgs...)
}
// NotErrorAsf asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func (a *Assertions) NotErrorAsf(err error, target interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotErrorAsf(a.t, err, target, msg, args...)
}
// NotErrorIs asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
@ -1209,7 +1297,7 @@ func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface
return NotErrorIs(a.t, err, target, msgAndArgs...)
}
// NotErrorIsf asserts that at none of the errors in err's chain matches target.
// NotErrorIsf asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
@ -1326,12 +1414,15 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
return NotSamef(a.t, expected, actual, msg, args...)
}
// NotSubset asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
// NotSubset asserts that the list (array, slice, or map) does NOT contain all
// elements given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3})
// a.NotSubset([1, 3, 4], {1: "one", 2: "two"})
// a.NotSubset({"x": 1, "y": 2}, ["z"])
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1339,12 +1430,15 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
return NotSubset(a.t, list, subset, msgAndArgs...)
}
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
// NotSubsetf asserts that the list (array, slice, or map) does NOT contain all
// elements given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
// a.NotSubsetf([1, 3, 4], {1: "one", 2: "two"}, "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, ["z"], "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1504,11 +1598,15 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
return Samef(a.t, expected, actual, msg, args...)
}
// Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
// Subset asserts that the list (array, slice, or map) contains all elements
// given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1})
// a.Subset([1, 2, 3], {1: "one", 2: "two"})
// a.Subset({"x": 1, "y": 2}, ["x"])
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1516,11 +1614,15 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
return Subset(a.t, list, subset, msgAndArgs...)
}
// Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
// Subsetf asserts that the list (array, slice, or map) contains all elements
// given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
// a.Subsetf([1, 2, 3], {1: "one", 2: "two"}, "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, ["x"], "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()

View file

@ -6,7 +6,7 @@ import (
)
// isOrdered checks that collection contains orderable elements.
func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
func isOrdered(t TestingT, object interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool {
objKind := reflect.TypeOf(object).Kind()
if objKind != reflect.Slice && objKind != reflect.Array {
return false
@ -33,7 +33,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
compareResult, isComparable := compare(prevValueInterface, valueInterface, firstValueKind)
if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\" and \"%s\"", reflect.TypeOf(value), reflect.TypeOf(prevValue)), msgAndArgs...)
return Fail(t, fmt.Sprintf(`Can not compare type "%T" and "%T"`, value, prevValue), msgAndArgs...)
}
if !containsValue(allowedComparesResults, compareResult) {
@ -50,7 +50,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
// assert.IsIncreasing(t, []float{1, 2})
// assert.IsIncreasing(t, []string{"a", "b"})
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
return isOrdered(t, object, []compareResult{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// IsNonIncreasing asserts that the collection is not increasing
@ -59,7 +59,7 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonIncreasing(t, []float{2, 1})
// assert.IsNonIncreasing(t, []string{"b", "a"})
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
return isOrdered(t, object, []compareResult{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// IsDecreasing asserts that the collection is decreasing
@ -68,7 +68,7 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{})
// assert.IsDecreasing(t, []float{2, 1})
// assert.IsDecreasing(t, []string{"b", "a"})
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
return isOrdered(t, object, []compareResult{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// IsNonDecreasing asserts that the collection is not decreasing
@ -77,5 +77,5 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonDecreasing(t, []float{1, 2})
// assert.IsNonDecreasing(t, []string{"a", "b"})
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
return isOrdered(t, object, []compareResult{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}

View file

@ -19,7 +19,9 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib"
"gopkg.in/yaml.v3"
// Wrapper around gopkg.in/yaml.v3
"github.com/stretchr/testify/assert/yaml"
)
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl"
@ -45,6 +47,10 @@ type BoolAssertionFunc func(TestingT, bool, ...interface{}) bool
// for table driven tests.
type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool
// PanicAssertionFunc is a common function prototype when validating a panic value. Can be useful
// for table driven tests.
type PanicAssertionFunc = func(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool
// Comparison is a custom function that returns true on success and false on failure
type Comparison func() (success bool)
@ -204,59 +210,77 @@ the problem actually occurred in calling code.*/
// of each stack frame leading from the current test to the assert call that
// failed.
func CallerInfo() []string {
var pc uintptr
var ok bool
var file string
var line int
var name string
const stackFrameBufferSize = 10
pcs := make([]uintptr, stackFrameBufferSize)
callers := []string{}
for i := 0; ; i++ {
pc, file, line, ok = runtime.Caller(i)
if !ok {
// The breaks below failed to terminate the loop, and we ran off the
// end of the call stack.
offset := 1
for {
n := runtime.Callers(offset, pcs)
if n == 0 {
break
}
// This is a huge edge case, but it will panic if this is the case, see #180
if file == "<autogenerated>" {
break
}
frames := runtime.CallersFrames(pcs[:n])
f := runtime.FuncForPC(pc)
if f == nil {
break
}
name = f.Name()
for {
frame, more := frames.Next()
pc = frame.PC
file = frame.File
line = frame.Line
// testing.tRunner is the standard library function that calls
// tests. Subtests are called directly by tRunner, without going through
// the Test/Benchmark/Example function that contains the t.Run calls, so
// with subtests we should break when we hit tRunner, without adding it
// to the list of callers.
if name == "testing.tRunner" {
break
}
// This is a huge edge case, but it will panic if this is the case, see #180
if file == "<autogenerated>" {
break
}
parts := strings.Split(file, "/")
if len(parts) > 1 {
filename := parts[len(parts)-1]
dir := parts[len(parts)-2]
if (dir != "assert" && dir != "mock" && dir != "require") || filename == "mock_test.go" {
callers = append(callers, fmt.Sprintf("%s:%d", file, line))
f := runtime.FuncForPC(pc)
if f == nil {
break
}
name = f.Name()
// testing.tRunner is the standard library function that calls
// tests. Subtests are called directly by tRunner, without going through
// the Test/Benchmark/Example function that contains the t.Run calls, so
// with subtests we should break when we hit tRunner, without adding it
// to the list of callers.
if name == "testing.tRunner" {
break
}
parts := strings.Split(file, "/")
if len(parts) > 1 {
filename := parts[len(parts)-1]
dir := parts[len(parts)-2]
if (dir != "assert" && dir != "mock" && dir != "require") || filename == "mock_test.go" {
callers = append(callers, fmt.Sprintf("%s:%d", file, line))
}
}
// Drop the package
dotPos := strings.LastIndexByte(name, '.')
name = name[dotPos+1:]
if isTest(name, "Test") ||
isTest(name, "Benchmark") ||
isTest(name, "Example") {
break
}
if !more {
break
}
}
// Drop the package
segments := strings.Split(name, ".")
name = segments[len(segments)-1]
if isTest(name, "Test") ||
isTest(name, "Benchmark") ||
isTest(name, "Example") {
break
}
// Next batch
offset += cap(pcs)
}
return callers
@ -431,17 +455,34 @@ func NotImplements(t TestingT, interfaceObject interface{}, object interface{},
return true
}
func isType(expectedType, object interface{}) bool {
return ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType))
}
// IsType asserts that the specified objects are of the same type.
func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
//
// assert.IsType(t, &MyStruct{}, &MyStruct{})
func IsType(t TestingT, expectedType, object interface{}, msgAndArgs ...interface{}) bool {
if isType(expectedType, object) {
return true
}
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Fail(t, fmt.Sprintf("Object expected to be of type %T, but was %T", expectedType, object), msgAndArgs...)
}
if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) {
return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...)
// IsNotType asserts that the specified objects are not of the same type.
//
// assert.IsNotType(t, &NotMyStruct{}, &MyStruct{})
func IsNotType(t TestingT, theType, object interface{}, msgAndArgs ...interface{}) bool {
if !isType(theType, object) {
return true
}
return true
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Fail(t, fmt.Sprintf("Object type expected to be different than %T", theType), msgAndArgs...)
}
// Equal asserts that two objects are equal.
@ -469,7 +510,6 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{})
}
return true
}
// validateEqualArgs checks whether provided arguments can be safely used in the
@ -496,10 +536,17 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b
h.Helper()
}
if !samePointers(expected, actual) {
same, ok := samePointers(expected, actual)
if !ok {
return Fail(t, "Both arguments must be pointers", msgAndArgs...)
}
if !same {
// both are pointers but not the same type & pointing to the same address
return Fail(t, fmt.Sprintf("Not same: \n"+
"expected: %p %#v\n"+
"actual : %p %#v", expected, expected, actual, actual), msgAndArgs...)
"expected: %p %#[1]v\n"+
"actual : %p %#[2]v",
expected, actual), msgAndArgs...)
}
return true
@ -516,29 +563,37 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
h.Helper()
}
if samePointers(expected, actual) {
same, ok := samePointers(expected, actual)
if !ok {
// fails when the arguments are not pointers
return !(Fail(t, "Both arguments must be pointers", msgAndArgs...))
}
if same {
return Fail(t, fmt.Sprintf(
"Expected and actual point to the same object: %p %#v",
expected, expected), msgAndArgs...)
"Expected and actual point to the same object: %p %#[1]v",
expected), msgAndArgs...)
}
return true
}
// samePointers compares two generic interface objects and returns whether
// they point to the same object
func samePointers(first, second interface{}) bool {
// samePointers checks if two generic interface objects are pointers of the same
// type pointing to the same object. It returns two values: same indicating if
// they are the same type and point to the same object, and ok indicating that
// both inputs are pointers.
func samePointers(first, second interface{}) (same bool, ok bool) {
firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second)
if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr {
return false
return false, false // not both are pointers
}
firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second)
if firstType != secondType {
return false
return false, true // both are pointers, but of different types
}
// compare pointer addresses
return first == second
return first == second, true
}
// formatUnequalValues takes two values of arbitrary types and returns string
@ -572,8 +627,8 @@ func truncatingFormat(data interface{}) string {
return value
}
// EqualValues asserts that two objects are equal or convertible to the same types
// and equal.
// EqualValues asserts that two objects are equal or convertible to the larger
// type and equal.
//
// assert.EqualValues(t, uint32(123), int32(123))
func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
@ -590,7 +645,6 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa
}
return true
}
// EqualExportedValues asserts that the types of two objects are equal and their public
@ -615,21 +669,6 @@ func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs ..
return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...)
}
if aType.Kind() == reflect.Ptr {
aType = aType.Elem()
}
if bType.Kind() == reflect.Ptr {
bType = bType.Elem()
}
if aType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...)
}
if bType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...)
}
expected = copyExportedFields(expected)
actual = copyExportedFields(actual)
@ -660,7 +699,6 @@ func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
}
return Equal(t, expected, actual, msgAndArgs...)
}
// NotNil asserts that the specified object is not nil.
@ -710,37 +748,45 @@ func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
// isEmpty gets whether the specified object is considered empty or not.
func isEmpty(object interface{}) bool {
// get nil case out of the way
if object == nil {
return true
}
objValue := reflect.ValueOf(object)
switch objValue.Kind() {
// collection types are empty when they have no element
case reflect.Chan, reflect.Map, reflect.Slice:
return objValue.Len() == 0
// pointers are empty if nil or if the value they point to is empty
case reflect.Ptr:
if objValue.IsNil() {
return true
}
deref := objValue.Elem().Interface()
return isEmpty(deref)
// for all other types, compare against the zero value
// array types are empty when they match their zero-initialized state
default:
zero := reflect.Zero(objValue.Type())
return reflect.DeepEqual(object, zero.Interface())
}
return isEmptyValue(reflect.ValueOf(object))
}
// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
// isEmptyValue gets whether the specified reflect.Value is considered empty or not.
func isEmptyValue(objValue reflect.Value) bool {
if objValue.IsZero() {
return true
}
// Special cases of non-zero values that we consider empty
switch objValue.Kind() {
// collection types are empty when they have no element
// Note: array types are empty when they match their zero-initialized state.
case reflect.Chan, reflect.Map, reflect.Slice:
return objValue.Len() == 0
// non-nil pointers are empty if the value they point to is empty
case reflect.Ptr:
return isEmptyValue(objValue.Elem())
}
return false
}
// Empty asserts that the given value is "empty".
//
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
//
// assert.Empty(t, obj)
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
pass := isEmpty(object)
if !pass {
@ -751,11 +797,9 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
}
return pass
}
// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
// NotEmpty asserts that the specified object is NOT [Empty].
//
// if assert.NotEmpty(t, obj) {
// assert.Equal(t, "two", obj[1])
@ -770,7 +814,6 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
}
return pass
}
// getLen tries to get the length of an object.
@ -814,7 +857,6 @@ func True(t TestingT, value bool, msgAndArgs ...interface{}) bool {
}
return true
}
// False asserts that the specified value is false.
@ -829,7 +871,6 @@ func False(t TestingT, value bool, msgAndArgs ...interface{}) bool {
}
return true
}
// NotEqual asserts that the specified values are NOT equal.
@ -852,7 +893,6 @@ func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{
}
return true
}
// NotEqualValues asserts that two objects are not equal even when converted to the same type
@ -875,7 +915,6 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte
// return (true, false) if element was not found.
// return (true, true) if element was found.
func containsElement(list interface{}, element interface{}) (ok, found bool) {
listValue := reflect.ValueOf(list)
listType := reflect.TypeOf(list)
if listType == nil {
@ -910,7 +949,6 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) {
}
}
return true, false
}
// Contains asserts that the specified string, list(array, slice...) or map contains the
@ -933,7 +971,6 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
}
return true
}
// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the
@ -956,14 +993,17 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
}
return true
}
// Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
// Subset asserts that the list (array, slice, or map) contains all elements
// given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// assert.Subset(t, [1, 2, 3], [1, 2])
// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1})
// assert.Subset(t, [1, 2, 3], {1: "one", 2: "two"})
// assert.Subset(t, {"x": 1, "y": 2}, ["x"])
func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -978,7 +1018,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
}
subsetKind := reflect.TypeOf(subset).Kind()
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
if subsetKind != reflect.Array && subsetKind != reflect.Slice && subsetKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
}
@ -1002,6 +1042,13 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
}
subsetList := reflect.ValueOf(subset)
if subsetKind == reflect.Map {
keys := make([]interface{}, subsetList.Len())
for idx, key := range subsetList.MapKeys() {
keys[idx] = key.Interface()
}
subsetList = reflect.ValueOf(keys)
}
for i := 0; i < subsetList.Len(); i++ {
element := subsetList.Index(i).Interface()
ok, found := containsElement(list, element)
@ -1016,12 +1063,15 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true
}
// NotSubset asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
// NotSubset asserts that the list (array, slice, or map) does NOT contain all
// elements given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// assert.NotSubset(t, [1, 3, 4], [1, 2])
// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3})
// assert.NotSubset(t, [1, 3, 4], {1: "one", 2: "two"})
// assert.NotSubset(t, {"x": 1, "y": 2}, ["z"])
func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -1036,7 +1086,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
}
subsetKind := reflect.TypeOf(subset).Kind()
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
if subsetKind != reflect.Array && subsetKind != reflect.Slice && subsetKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
}
@ -1060,11 +1110,18 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
}
subsetList := reflect.ValueOf(subset)
if subsetKind == reflect.Map {
keys := make([]interface{}, subsetList.Len())
for idx, key := range subsetList.MapKeys() {
keys[idx] = key.Interface()
}
subsetList = reflect.ValueOf(keys)
}
for i := 0; i < subsetList.Len(); i++ {
element := subsetList.Index(i).Interface()
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
return Fail(t, fmt.Sprintf("%q could not be applied builtin len()", list), msgAndArgs...)
}
if !found {
return true
@ -1170,6 +1227,39 @@ func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) stri
return msg.String()
}
// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// assert.NotElementsMatch(t, [1, 1, 2, 3], [1, 1, 2, 3]) -> false
//
// assert.NotElementsMatch(t, [1, 1, 2, 3], [1, 2, 3]) -> true
//
// assert.NotElementsMatch(t, [1, 2, 3], [1, 2, 4]) -> true
func NotElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if isEmpty(listA) && isEmpty(listB) {
return Fail(t, "listA and listB contain the same elements", msgAndArgs)
}
if !isList(t, listA, msgAndArgs...) {
return Fail(t, "listA is not a list type", msgAndArgs...)
}
if !isList(t, listB, msgAndArgs...) {
return Fail(t, "listB is not a list type", msgAndArgs...)
}
extraA, extraB := diffLists(listA, listB)
if len(extraA) == 0 && len(extraB) == 0 {
return Fail(t, "listA and listB contain the same elements", msgAndArgs)
}
return true
}
// Condition uses a Comparison to assert a complex condition.
func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
@ -1488,6 +1578,9 @@ func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAnd
if err != nil {
return Fail(t, err.Error(), msgAndArgs...)
}
if math.IsNaN(actualEpsilon) {
return Fail(t, "relative error is NaN", msgAndArgs...)
}
if actualEpsilon > epsilon {
return Fail(t, fmt.Sprintf("Relative error is too high: %#v (expected)\n"+
" < %#v (actual)", epsilon, actualEpsilon), msgAndArgs...)
@ -1550,10 +1643,8 @@ func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool {
// Error asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if assert.Error(t, err) {
// assert.Equal(t, expectedError, err)
// }
// actualObj, err := SomeFunction()
// assert.Error(t, err)
func Error(t TestingT, err error, msgAndArgs ...interface{}) bool {
if err == nil {
if h, ok := t.(tHelper); ok {
@ -1611,7 +1702,6 @@ func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...in
// matchRegexp return true if a specified regexp matches a string.
func matchRegexp(rx interface{}, str interface{}) bool {
var r *regexp.Regexp
if rr, ok := rx.(*regexp.Regexp); ok {
r = rr
@ -1619,8 +1709,14 @@ func matchRegexp(rx interface{}, str interface{}) bool {
r = regexp.MustCompile(fmt.Sprint(rx))
}
return (r.FindStringIndex(fmt.Sprint(str)) != nil)
switch v := str.(type) {
case []byte:
return r.Match(v)
case string:
return r.MatchString(v)
default:
return r.MatchString(fmt.Sprint(v))
}
}
// Regexp asserts that a specified regexp matches a string.
@ -1656,7 +1752,6 @@ func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interf
}
return !match
}
// Zero asserts that i is the zero value for its type.
@ -1767,6 +1862,11 @@ func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{
return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid json.\nJSON parsing error: '%s'", expected, err.Error()), msgAndArgs...)
}
// Shortcut if same bytes
if actual == expected {
return true
}
if err := json.Unmarshal([]byte(actual), &actualJSONAsInterface); err != nil {
return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid json.\nJSON parsing error: '%s'", actual, err.Error()), msgAndArgs...)
}
@ -1785,6 +1885,11 @@ func YAMLEq(t TestingT, expected string, actual string, msgAndArgs ...interface{
return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid yaml.\nYAML parsing error: '%s'", expected, err.Error()), msgAndArgs...)
}
// Shortcut if same bytes
if actual == expected {
return true
}
if err := yaml.Unmarshal([]byte(actual), &actualYAMLAsInterface); err != nil {
return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid yaml.\nYAML error: '%s'", actual, err.Error()), msgAndArgs...)
}
@ -1872,7 +1977,7 @@ var spewConfigStringerEnabled = spew.ConfigState{
MaxDepth: 10,
}
type tHelper interface {
type tHelper = interface {
Helper()
}
@ -1886,6 +1991,7 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
}
ch := make(chan bool, 1)
checkCond := func() { ch <- condition() }
timer := time.NewTimer(waitFor)
defer timer.Stop()
@ -1893,35 +1999,47 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
ticker := time.NewTicker(tick)
defer ticker.Stop()
for tick := ticker.C; ; {
var tickC <-chan time.Time
// Check the condition once first on the initial call.
go checkCond()
for {
select {
case <-timer.C:
return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick:
tick = nil
go func() { ch <- condition() }()
case <-tickC:
tickC = nil
go checkCond()
case v := <-ch:
if v {
return true
}
tick = ticker.C
tickC = ticker.C
}
}
}
// CollectT implements the TestingT interface and collects all errors.
type CollectT struct {
// A slice of errors. Non-nil slice denotes a failure.
// If it's non-nil but len(c.errors) == 0, this is also a failure
// obtained by direct c.FailNow() call.
errors []error
}
// Helper is like [testing.T.Helper] but does nothing.
func (CollectT) Helper() {}
// Errorf collects the error.
func (c *CollectT) Errorf(format string, args ...interface{}) {
c.errors = append(c.errors, fmt.Errorf(format, args...))
}
// FailNow panics.
func (*CollectT) FailNow() {
panic("Assertion failed")
// FailNow stops execution by calling runtime.Goexit.
func (c *CollectT) FailNow() {
c.fail()
runtime.Goexit()
}
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
@ -1934,6 +2052,16 @@ func (*CollectT) Copy(TestingT) {
panic("Copy() is deprecated")
}
func (c *CollectT) fail() {
if !c.failed() {
c.errors = []error{} // Make it non-nil to mark a failure.
}
}
func (c *CollectT) failed() bool {
return c.errors != nil
}
// EventuallyWithT asserts that given condition will be met in waitFor time,
// periodically checking target function each tick. In contrast to Eventually,
// it supplies a CollectT to the condition function, so that the condition
@ -1951,14 +2079,22 @@ func (*CollectT) Copy(TestingT) {
// assert.EventuallyWithT(t, func(c *assert.CollectT) {
// // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false")
// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
var lastFinishedTickErrs []error
ch := make(chan []error, 1)
ch := make(chan *CollectT, 1)
checkCond := func() {
collect := new(CollectT)
defer func() {
ch <- collect
}()
condition(collect)
}
timer := time.NewTimer(waitFor)
defer timer.Stop()
@ -1966,29 +2102,28 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
ticker := time.NewTicker(tick)
defer ticker.Stop()
for tick := ticker.C; ; {
var tickC <-chan time.Time
// Check the condition once first on the initial call.
go checkCond()
for {
select {
case <-timer.C:
for _, err := range lastFinishedTickErrs {
t.Errorf("%v", err)
}
return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick:
tick = nil
go func() {
collect := new(CollectT)
defer func() {
ch <- collect.errors
}()
condition(collect)
}()
case errs := <-ch:
if len(errs) == 0 {
case <-tickC:
tickC = nil
go checkCond()
case collect := <-ch:
if !collect.failed() {
return true
}
// Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached.
lastFinishedTickErrs = errs
tick = ticker.C
lastFinishedTickErrs = collect.errors
tickC = ticker.C
}
}
}
@ -2003,6 +2138,7 @@ func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.D
}
ch := make(chan bool, 1)
checkCond := func() { ch <- condition() }
timer := time.NewTimer(waitFor)
defer timer.Stop()
@ -2010,18 +2146,23 @@ func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.D
ticker := time.NewTicker(tick)
defer ticker.Stop()
for tick := ticker.C; ; {
var tickC <-chan time.Time
// Check the condition once first on the initial call.
go checkCond()
for {
select {
case <-timer.C:
return true
case <-tick:
tick = nil
go func() { ch <- condition() }()
case <-tickC:
tickC = nil
go checkCond()
case v := <-ch:
if v {
return Fail(t, "Condition satisfied", msgAndArgs...)
}
tick = ticker.C
tickC = ticker.C
}
}
}
@ -2039,9 +2180,12 @@ func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
var expectedText string
if target != nil {
expectedText = target.Error()
if err == nil {
return Fail(t, fmt.Sprintf("Expected error with %q in chain but got nil.", expectedText), msgAndArgs...)
}
}
chain := buildErrorChainString(err)
chain := buildErrorChainString(err, false)
return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+
"expected: %q\n"+
@ -2049,7 +2193,7 @@ func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
), msgAndArgs...)
}
// NotErrorIs asserts that at none of the errors in err's chain matches target.
// NotErrorIs asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
@ -2064,7 +2208,7 @@ func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
expectedText = target.Error()
}
chain := buildErrorChainString(err)
chain := buildErrorChainString(err, false)
return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+
"found: %q\n"+
@ -2082,24 +2226,70 @@ func ErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{
return true
}
chain := buildErrorChainString(err)
expectedType := reflect.TypeOf(target).Elem().String()
if err == nil {
return Fail(t, fmt.Sprintf("An error is expected but got nil.\n"+
"expected: %s", expectedType), msgAndArgs...)
}
chain := buildErrorChainString(err, true)
return Fail(t, fmt.Sprintf("Should be in error chain:\n"+
"expected: %q\n"+
"in chain: %s", target, chain,
"expected: %s\n"+
"in chain: %s", expectedType, chain,
), msgAndArgs...)
}
func buildErrorChainString(err error) string {
// NotErrorAs asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func NotErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !errors.As(err, target) {
return true
}
chain := buildErrorChainString(err, true)
return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+
"found: %s\n"+
"in chain: %s", reflect.TypeOf(target).Elem().String(), chain,
), msgAndArgs...)
}
func unwrapAll(err error) (errs []error) {
errs = append(errs, err)
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return
}
errs = append(errs, unwrapAll(err)...)
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
errs = append(errs, unwrapAll(err)...)
}
}
return
}
func buildErrorChainString(err error, withType bool) string {
if err == nil {
return ""
}
e := errors.Unwrap(err)
chain := fmt.Sprintf("%q", err.Error())
for e != nil {
chain += fmt.Sprintf("\n\t%q", e.Error())
e = errors.Unwrap(e)
var chain string
errs := unwrapAll(err)
for i := range errs {
if i != 0 {
chain += "\n\t"
}
chain += fmt.Sprintf("%q", errs[i].Error())
if withType {
chain += fmt.Sprintf(" (%T)", errs[i])
}
}
return chain
}

View file

@ -1,5 +1,9 @@
// Package assert provides a set of comprehensive testing tools for use with the normal Go testing system.
//
// # Note
//
// All functions in this package return a bool value indicating whether the assertion has passed.
//
// # Example Usage
//
// The following is a complete example using assert in a standard test function:

View file

@ -138,7 +138,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string,
contains := strings.Contains(body, fmt.Sprint(str))
if !contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
Fail(t, fmt.Sprintf("Expected response body for %q to contain %q but found %q", url+"?"+values.Encode(), str, body), msgAndArgs...)
}
return contains
@ -158,7 +158,7 @@ func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url strin
contains := strings.Contains(body, fmt.Sprint(str))
if contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
Fail(t, fmt.Sprintf("Expected response body for %q to NOT contain %q but found %q", url+"?"+values.Encode(), str, body), msgAndArgs...)
}
return !contains

View file

@ -0,0 +1,24 @@
//go:build testify_yaml_custom && !testify_yaml_fail && !testify_yaml_default
// Package yaml is an implementation of YAML functions that calls a pluggable implementation.
//
// This implementation is selected with the testify_yaml_custom build tag.
//
// go test -tags testify_yaml_custom
//
// This implementation can be used at build time to replace the default implementation
// to avoid linking with [gopkg.in/yaml.v3].
//
// In your test package:
//
// import assertYaml "github.com/stretchr/testify/assert/yaml"
//
// func init() {
// assertYaml.Unmarshal = func (in []byte, out interface{}) error {
// // ...
// return nil
// }
// }
package yaml
var Unmarshal func(in []byte, out interface{}) error

View file

@ -0,0 +1,36 @@
//go:build !testify_yaml_fail && !testify_yaml_custom
// Package yaml is just an indirection to handle YAML deserialization.
//
// This package is just an indirection that allows the builder to override the
// indirection with an alternative implementation of this package that uses
// another implementation of YAML deserialization. This allows to not either not
// use YAML deserialization at all, or to use another implementation than
// [gopkg.in/yaml.v3] (for example for license compatibility reasons, see [PR #1120]).
//
// Alternative implementations are selected using build tags:
//
// - testify_yaml_fail: [Unmarshal] always fails with an error
// - testify_yaml_custom: [Unmarshal] is a variable. Caller must initialize it
// before calling any of [github.com/stretchr/testify/assert.YAMLEq] or
// [github.com/stretchr/testify/assert.YAMLEqf].
//
// Usage:
//
// go test -tags testify_yaml_fail
//
// You can check with "go list" which implementation is linked:
//
// go list -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml
// go list -tags testify_yaml_fail -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml
// go list -tags testify_yaml_custom -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml
//
// [PR #1120]: https://github.com/stretchr/testify/pull/1120
package yaml
import goyaml "gopkg.in/yaml.v3"
// Unmarshal is just a wrapper of [gopkg.in/yaml.v3.Unmarshal].
func Unmarshal(in []byte, out interface{}) error {
return goyaml.Unmarshal(in, out)
}

View file

@ -0,0 +1,17 @@
//go:build testify_yaml_fail && !testify_yaml_custom && !testify_yaml_default
// Package yaml is an implementation of YAML functions that always fail.
//
// This implementation can be used at build time to replace the default implementation
// to avoid linking with [gopkg.in/yaml.v3]:
//
// go test -tags testify_yaml_fail
package yaml
import "errors"
var errNotImplemented = errors.New("YAML functions are not available (see https://pkg.go.dev/github.com/stretchr/testify/assert/yaml)")
func Unmarshal([]byte, interface{}) error {
return errNotImplemented
}

View file

@ -23,6 +23,8 @@
//
// The `require` package have same global functions as in the `assert` package,
// but instead of returning a boolean result they call `t.FailNow()`.
// A consequence of this is that it must be called from the goroutine running
// the test function, not from other goroutines created during the test.
//
// Every assertion function also takes an optional string message as the final argument,
// allowing custom error messages to be appended to the message the assertion method outputs.

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
{{.Comment}}
{{ replace .Comment "assert." "require."}}
func {{.DocInfo.Name}}(t TestingT, {{.Params}}) {
if h, ok := t.(tHelper); ok { h.Helper() }
if assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { return }

View file

@ -93,10 +93,19 @@ func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg st
ElementsMatchf(a.t, listA, listB, msg, args...)
}
// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
// Empty asserts that the given value is "empty".
//
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
//
// a.Empty(obj)
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -104,10 +113,19 @@ func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) {
Empty(a.t, object, msgAndArgs...)
}
// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
// Emptyf asserts that the given value is "empty".
//
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
//
// a.Emptyf(obj, "error message %s", "formatted")
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -187,8 +205,8 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
EqualExportedValuesf(a.t, expected, actual, msg, args...)
}
// EqualValues asserts that two objects are equal or convertible to the same types
// and equal.
// EqualValues asserts that two objects are equal or convertible to the larger
// type and equal.
//
// a.EqualValues(uint32(123), int32(123))
func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
@ -198,8 +216,8 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
EqualValues(a.t, expected, actual, msgAndArgs...)
}
// EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal.
// EqualValuesf asserts that two objects are equal or convertible to the larger
// type and equal.
//
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) {
@ -225,10 +243,8 @@ func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string
// Error asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if a.Error(err) {
// assert.Equal(t, expectedError, err)
// }
// actualObj, err := SomeFunction()
// a.Error(err)
func (a *Assertions) Error(err error, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -298,10 +314,8 @@ func (a *Assertions) ErrorIsf(err error, target error, msg string, args ...inter
// Errorf asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if a.Errorf(err, "error message %s", "formatted") {
// assert.Equal(t, expectedErrorf, err)
// }
// actualObj, err := SomeFunction()
// a.Errorf(err, "error message %s", "formatted")
func (a *Assertions) Errorf(err error, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -337,7 +351,7 @@ func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, ti
// a.EventuallyWithT(func(c *assert.CollectT) {
// // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false")
// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func (a *Assertions) EventuallyWithT(condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -362,7 +376,7 @@ func (a *Assertions) EventuallyWithT(condition func(collect *assert.CollectT), w
// a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") {
// // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false")
// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func (a *Assertions) EventuallyWithTf(condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -869,7 +883,29 @@ func (a *Assertions) IsNonIncreasingf(object interface{}, msg string, args ...in
IsNonIncreasingf(a.t, object, msg, args...)
}
// IsNotType asserts that the specified objects are not of the same type.
//
// a.IsNotType(&NotMyStruct{}, &MyStruct{})
func (a *Assertions) IsNotType(theType interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
IsNotType(a.t, theType, object, msgAndArgs...)
}
// IsNotTypef asserts that the specified objects are not of the same type.
//
// a.IsNotTypef(&NotMyStruct{}, &MyStruct{}, "error message %s", "formatted")
func (a *Assertions) IsNotTypef(theType interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
IsNotTypef(a.t, theType, object, msg, args...)
}
// IsType asserts that the specified objects are of the same type.
//
// a.IsType(&MyStruct{}, &MyStruct{})
func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -878,6 +914,8 @@ func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAnd
}
// IsTypef asserts that the specified objects are of the same type.
//
// a.IsTypef(&MyStruct{}, &MyStruct{}, "error message %s", "formatted")
func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1129,8 +1167,41 @@ func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg strin
NotContainsf(a.t, s, contains, msg, args...)
}
// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// a.NotElementsMatch([1, 1, 2, 3], [1, 1, 2, 3]) -> false
//
// a.NotElementsMatch([1, 1, 2, 3], [1, 2, 3]) -> true
//
// a.NotElementsMatch([1, 2, 3], [1, 2, 4]) -> true
func (a *Assertions) NotElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotElementsMatch(a.t, listA, listB, msgAndArgs...)
}
// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// a.NotElementsMatchf([1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false
//
// a.NotElementsMatchf([1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true
//
// a.NotElementsMatchf([1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true
func (a *Assertions) NotElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotElementsMatchf(a.t, listA, listB, msg, args...)
}
// NotEmpty asserts that the specified object is NOT [Empty].
//
// if a.NotEmpty(obj) {
// assert.Equal(t, "two", obj[1])
@ -1142,8 +1213,7 @@ func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) {
NotEmpty(a.t, object, msgAndArgs...)
}
// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
// NotEmptyf asserts that the specified object is NOT [Empty].
//
// if a.NotEmptyf(obj, "error message %s", "formatted") {
// assert.Equal(t, "two", obj[1])
@ -1201,7 +1271,25 @@ func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg str
NotEqualf(a.t, expected, actual, msg, args...)
}
// NotErrorIs asserts that at none of the errors in err's chain matches target.
// NotErrorAs asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func (a *Assertions) NotErrorAs(err error, target interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotErrorAs(a.t, err, target, msgAndArgs...)
}
// NotErrorAsf asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func (a *Assertions) NotErrorAsf(err error, target interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotErrorAsf(a.t, err, target, msg, args...)
}
// NotErrorIs asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
@ -1210,7 +1298,7 @@ func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface
NotErrorIs(a.t, err, target, msgAndArgs...)
}
// NotErrorIsf asserts that at none of the errors in err's chain matches target.
// NotErrorIsf asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
@ -1327,12 +1415,15 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
NotSamef(a.t, expected, actual, msg, args...)
}
// NotSubset asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
// NotSubset asserts that the list (array, slice, or map) does NOT contain all
// elements given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3})
// a.NotSubset([1, 3, 4], {1: "one", 2: "two"})
// a.NotSubset({"x": 1, "y": 2}, ["z"])
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1340,12 +1431,15 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
NotSubset(a.t, list, subset, msgAndArgs...)
}
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
// NotSubsetf asserts that the list (array, slice, or map) does NOT contain all
// elements given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
// a.NotSubsetf([1, 3, 4], {1: "one", 2: "two"}, "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, ["z"], "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1505,11 +1599,15 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
Samef(a.t, expected, actual, msg, args...)
}
// Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
// Subset asserts that the list (array, slice, or map) contains all elements
// given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1})
// a.Subset([1, 2, 3], {1: "one", 2: "two"})
// a.Subset({"x": 1, "y": 2}, ["x"])
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1517,11 +1615,15 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
Subset(a.t, list, subset, msgAndArgs...)
}
// Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
// Subsetf asserts that the list (array, slice, or map) contains all elements
// given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
//
// a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
// a.Subsetf([1, 2, 3], {1: "one", 2: "two"}, "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, ["x"], "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()

View file

@ -6,7 +6,7 @@ type TestingT interface {
FailNow()
}
type tHelper interface {
type tHelper = interface {
Helper()
}

4
vendor/golang.org/x/crypto/LICENSE generated vendored
View file

@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
* Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

View file

@ -31,12 +31,11 @@ import (
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"math/big"
"net"
"net/http"
"strings"
"sync"
@ -353,6 +352,10 @@ func (c *Client) authorize(ctx context.Context, typ, val string) (*Authorization
if _, err := c.Discover(ctx); err != nil {
return nil, err
}
if c.dir.AuthzURL == "" {
// Pre-Authorization is unsupported
return nil, errPreAuthorizationNotSupported
}
type authzID struct {
Type string `json:"type"`
@ -467,7 +470,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
// while waiting for a final authorization status.
d := retryAfter(res.Header.Get("Retry-After"))
if d == 0 {
// Given that the fastest challenges TLS-SNI and HTTP-01
// Given that the fastest challenges TLS-ALPN and HTTP-01
// require a CA to make at least 1 network round trip
// and most likely persist a challenge state,
// this default delay seems reasonable.
@ -514,7 +517,11 @@ func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error
return nil, err
}
res, err := c.post(ctx, nil, chal.URI, json.RawMessage("{}"), wantStatus(
payload := json.RawMessage("{}")
if len(chal.Payload) != 0 {
payload = chal.Payload
}
res, err := c.post(ctx, nil, chal.URI, payload, wantStatus(
http.StatusOK, // according to the spec
http.StatusAccepted, // Let's Encrypt: see https://goo.gl/WsJ7VT (acme-divergences.md)
))
@ -564,50 +571,28 @@ func (c *Client) HTTP01ChallengePath(token string) string {
}
// TLSSNI01ChallengeCert creates a certificate for TLS-SNI-01 challenge response.
// Always returns an error.
//
// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec.
func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) {
ka, err := keyAuth(c.Key.Public(), token)
if err != nil {
return tls.Certificate{}, "", err
}
b := sha256.Sum256([]byte(ka))
h := hex.EncodeToString(b[:])
name = fmt.Sprintf("%s.%s.acme.invalid", h[:32], h[32:])
cert, err = tlsChallengeCert([]string{name}, opt)
if err != nil {
return tls.Certificate{}, "", err
}
return cert, name, nil
// Deprecated: This challenge type was only present in pre-standardized ACME
// protocol drafts and is insecure for use in shared hosting environments.
func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (tls.Certificate, string, error) {
return tls.Certificate{}, "", errPreRFC
}
// TLSSNI02ChallengeCert creates a certificate for TLS-SNI-02 challenge response.
// Always returns an error.
//
// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec.
func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) {
b := sha256.Sum256([]byte(token))
h := hex.EncodeToString(b[:])
sanA := fmt.Sprintf("%s.%s.token.acme.invalid", h[:32], h[32:])
ka, err := keyAuth(c.Key.Public(), token)
if err != nil {
return tls.Certificate{}, "", err
}
b = sha256.Sum256([]byte(ka))
h = hex.EncodeToString(b[:])
sanB := fmt.Sprintf("%s.%s.ka.acme.invalid", h[:32], h[32:])
cert, err = tlsChallengeCert([]string{sanA, sanB}, opt)
if err != nil {
return tls.Certificate{}, "", err
}
return cert, sanA, nil
// Deprecated: This challenge type was only present in pre-standardized ACME
// protocol drafts and is insecure for use in shared hosting environments.
func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (tls.Certificate, string, error) {
return tls.Certificate{}, "", errPreRFC
}
// TLSALPN01ChallengeCert creates a certificate for TLS-ALPN-01 challenge response.
// Servers can present the certificate to validate the challenge and prove control
// over a domain name. For more details on TLS-ALPN-01 see
// https://tools.ietf.org/html/draft-shoemaker-acme-tls-alpn-00#section-3
// over an identifier (either a DNS name or the textual form of an IPv4 or IPv6
// address). For more details on TLS-ALPN-01 see
// https://www.rfc-editor.org/rfc/rfc8737 and https://www.rfc-editor.org/rfc/rfc8738
//
// The token argument is a Challenge.Token value.
// If a WithKey option is provided, its private part signs the returned cert,
@ -615,9 +600,13 @@ func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tl
// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve.
//
// The returned certificate is valid for the next 24 hours and must be presented only when
// the server name in the TLS ClientHello matches the domain, and the special acme-tls/1 ALPN protocol
// the server name in the TLS ClientHello matches the identifier, and the special acme-tls/1 ALPN protocol
// has been specified.
func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption) (cert tls.Certificate, err error) {
//
// Validation requests for IP address identifiers will use the reverse DNS form in the server name
// in the TLS ClientHello since the SNI extension is not supported for IP addresses.
// See RFC 8738 Section 6 for more information.
func (c *Client) TLSALPN01ChallengeCert(token, identifier string, opt ...CertOption) (cert tls.Certificate, err error) {
ka, err := keyAuth(c.Key.Public(), token)
if err != nil {
return tls.Certificate{}, err
@ -647,7 +636,7 @@ func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption)
}
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, acmeExtension)
newOpt = append(newOpt, WithTemplate(tmpl))
return tlsChallengeCert([]string{domain}, newOpt)
return tlsChallengeCert(identifier, newOpt)
}
// popNonce returns a nonce value previously stored with c.addNonce
@ -701,7 +690,7 @@ func (c *Client) addNonce(h http.Header) {
}
func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
r, err := http.NewRequest("HEAD", url, nil)
r, err := http.NewRequestWithContext(ctx, "HEAD", url, nil)
if err != nil {
return "", err
}
@ -765,11 +754,15 @@ func defaultTLSChallengeCertTemplate() *x509.Certificate {
}
}
// tlsChallengeCert creates a temporary certificate for TLS-SNI challenges
// with the given SANs and auto-generated public/private key pair.
// The Subject Common Name is set to the first SAN to aid debugging.
// tlsChallengeCert creates a temporary certificate for TLS-ALPN challenges
// for the given identifier, using an auto-generated public/private key pair.
//
// If the provided identifier is a domain name, it will be used as a DNS type SAN and for the
// subject common name. If the provided identifier is an IP address it will be used as an IP type
// SAN.
//
// To create a cert with a custom key pair, specify WithKey option.
func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
func tlsChallengeCert(identifier string, opt []CertOption) (tls.Certificate, error) {
var key crypto.Signer
tmpl := defaultTLSChallengeCertTemplate()
for _, o := range opt {
@ -793,9 +786,12 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
return tls.Certificate{}, err
}
}
tmpl.DNSNames = san
if len(san) > 0 {
tmpl.Subject.CommonName = san[0]
if ip := net.ParseIP(identifier); ip != nil {
tmpl.IPAddresses = []net.IP{ip}
} else {
tmpl.DNSNames = []string{identifier}
tmpl.Subject.CommonName = identifier
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
@ -808,11 +804,5 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
}, nil
}
// encodePEM returns b encoded as PEM with block of type typ.
func encodePEM(typ string, b []byte) []byte {
pb := &pem.Block{Type: typ, Bytes: b}
return pem.EncodeToMemory(pb)
}
// timeNow is time.Now, except in tests which can mess with it.
var timeNow = time.Now

View file

@ -134,7 +134,8 @@ type Manager struct {
// RenewBefore optionally specifies how early certificates should
// be renewed before they expire.
//
// If zero, they're renewed 30 days before expiration.
// If zero, they're renewed at the lesser of 30 days or
// 1/3 of the certificate lifetime.
RenewBefore time.Duration
// Client is used to perform low-level operations, such as account registration
@ -292,6 +293,10 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
}
// regular domain
if err := m.hostPolicy()(ctx, name); err != nil {
return nil, err
}
ck := certKey{
domain: strings.TrimSuffix(name, "."), // golang.org/issue/18114
isRSA: !supportsECDSA(hello),
@ -305,9 +310,6 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
}
// first-time
if err := m.hostPolicy()(ctx, name); err != nil {
return nil, err
}
cert, err = m.createCert(ctx, ck)
if err != nil {
return nil, err
@ -463,7 +465,7 @@ func (m *Manager) cert(ctx context.Context, ck certKey) (*tls.Certificate, error
leaf: cert.Leaf,
}
m.state[ck] = s
m.startRenew(ck, s.key, s.leaf.NotAfter)
m.startRenew(ck, s.key, s.leaf.NotBefore, s.leaf.NotAfter)
return cert, nil
}
@ -609,7 +611,7 @@ func (m *Manager) createCert(ctx context.Context, ck certKey) (*tls.Certificate,
}
state.cert = der
state.leaf = leaf
m.startRenew(ck, state.key, state.leaf.NotAfter)
m.startRenew(ck, state.key, state.leaf.NotBefore, state.leaf.NotAfter)
return state.tlscert()
}
@ -907,7 +909,7 @@ func httpTokenCacheKey(tokenPath string) string {
//
// The key argument is a certificate private key.
// The exp argument is the cert expiration time (NotAfter).
func (m *Manager) startRenew(ck certKey, key crypto.Signer, exp time.Time) {
func (m *Manager) startRenew(ck certKey, key crypto.Signer, notBefore, notAfter time.Time) {
m.renewalMu.Lock()
defer m.renewalMu.Unlock()
if m.renewal[ck] != nil {
@ -919,7 +921,7 @@ func (m *Manager) startRenew(ck certKey, key crypto.Signer, exp time.Time) {
}
dr := &domainRenewal{m: m, ck: ck, key: key}
m.renewal[ck] = dr
dr.start(exp)
dr.start(notBefore, notAfter)
}
// stopRenew stops all currently running cert renewal timers.
@ -1027,13 +1029,6 @@ func (m *Manager) hostPolicy() HostPolicy {
return defaultHostPolicy
}
func (m *Manager) renewBefore() time.Duration {
if m.RenewBefore > renewJitter {
return m.RenewBefore
}
return 720 * time.Hour // 30 days
}
func (m *Manager) now() time.Time {
if m.nowFunc != nil {
return m.nowFunc()

View file

@ -10,7 +10,6 @@ import (
"net"
"os"
"path/filepath"
"runtime"
"time"
)
@ -124,32 +123,13 @@ func (ln *listener) Close() error {
return ln.tcpListener.Close()
}
func homeDir() string {
if runtime.GOOS == "windows" {
return os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
}
if h := os.Getenv("HOME"); h != "" {
return h
}
return "/"
}
func cacheDir() string {
const base = "golang-autocert"
switch runtime.GOOS {
case "darwin":
return filepath.Join(homeDir(), "Library", "Caches", base)
case "windows":
for _, ev := range []string{"APPDATA", "CSIDL_APPDATA", "TEMP", "TMP"} {
if v := os.Getenv(ev); v != "" {
return filepath.Join(v, base)
}
}
// Worst case:
return filepath.Join(homeDir(), base)
cache, err := os.UserCacheDir()
if err != nil {
// Fall back to the root directory.
cache = "/.cache"
}
if xdg := os.Getenv("XDG_CACHE_HOME"); xdg != "" {
return filepath.Join(xdg, base)
}
return filepath.Join(homeDir(), ".cache", base)
return filepath.Join(cache, base)
}

View file

@ -11,9 +11,6 @@ import (
"time"
)
// renewJitter is the maximum deviation from Manager.RenewBefore.
const renewJitter = time.Hour
// domainRenewal tracks the state used by the periodic timers
// renewing a single domain's cert.
type domainRenewal struct {
@ -30,13 +27,13 @@ type domainRenewal struct {
// defined by the certificate expiration time exp.
//
// If the timer is already started, calling start is a noop.
func (dr *domainRenewal) start(exp time.Time) {
func (dr *domainRenewal) start(notBefore, notAfter time.Time) {
dr.timerMu.Lock()
defer dr.timerMu.Unlock()
if dr.timer != nil {
return
}
dr.timer = time.AfterFunc(dr.next(exp), dr.renew)
dr.timer = time.AfterFunc(dr.next(notBefore, notAfter), dr.renew)
}
// stop stops the cert renewal timer and waits for any in-flight calls to renew
@ -79,7 +76,7 @@ func (dr *domainRenewal) renew() {
// TODO: rotate dr.key at some point?
next, err := dr.do(ctx)
if err != nil {
next = renewJitter / 2
next = time.Hour / 2
next += time.Duration(pseudoRand.int63n(int64(next)))
}
testDidRenewLoop(next, err)
@ -107,8 +104,8 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
// a race is likely unavoidable in a distributed environment
// but we try nonetheless
if tlscert, err := dr.m.cacheGet(ctx, dr.ck); err == nil {
next := dr.next(tlscert.Leaf.NotAfter)
if next > dr.m.renewBefore()+renewJitter {
next := dr.next(tlscert.Leaf.NotBefore, tlscert.Leaf.NotAfter)
if next > 0 {
signer, ok := tlscert.PrivateKey.(crypto.Signer)
if ok {
state := &certState{
@ -139,18 +136,23 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
return 0, err
}
dr.updateState(state)
return dr.next(leaf.NotAfter), nil
return dr.next(leaf.NotBefore, leaf.NotAfter), nil
}
func (dr *domainRenewal) next(expiry time.Time) time.Duration {
d := expiry.Sub(dr.m.now()) - dr.m.renewBefore()
// add a bit of randomness to renew deadline
n := pseudoRand.int63n(int64(renewJitter))
d -= time.Duration(n)
if d < 0 {
return 0
// next returns the wait time before the next renewal should start.
// If manager.RenewBefore is set, it uses that capped at 30 days,
// otherwise it uses a default of 1/3 of the cert lifetime.
// It builds in a jitter of 10% of the renew threshold, capped at 1 hour.
func (dr *domainRenewal) next(notBefore, notAfter time.Time) time.Duration {
threshold := min(notAfter.Sub(notBefore)/3, 30*24*time.Hour)
if dr.m.RenewBefore > 0 {
threshold = min(dr.m.RenewBefore, 30*24*time.Hour)
}
return d
maxJitter := min(threshold/10, time.Hour)
jitter := pseudoRand.int63n(int64(maxJitter))
renewAt := notAfter.Add(-(threshold - time.Duration(jitter)))
renewWait := renewAt.Sub(dr.m.now())
return max(0, renewWait)
}
var testDidRenewLoop = func(next time.Duration, err error) {}

View file

@ -15,6 +15,7 @@ import (
"io"
"math/big"
"net/http"
"runtime/debug"
"strconv"
"strings"
"time"
@ -65,7 +66,7 @@ func (c *Client) retryTimer() *retryTimer {
// The n argument is always bounded between 1 and 30.
// The returned value is always greater than 0.
func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration {
const max = 10 * time.Second
const maxVal = 10 * time.Second
var jitter time.Duration
if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil {
// Set the minimum to 1ms to avoid a case where
@ -85,10 +86,7 @@ func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration {
n = 30
}
d := time.Duration(1<<uint(n-1))*time.Second + jitter
if d > max {
return max
}
return d
return min(d, maxVal)
}
// retryAfter parses a Retry-After HTTP header value,
@ -130,7 +128,7 @@ func wantStatus(codes ...int) resOkay {
func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) {
retry := c.retryTimer()
for {
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
@ -230,7 +228,7 @@ func (c *Client) postNoRetry(ctx context.Context, key crypto.Signer, url string,
if err != nil {
return nil, nil, err
}
req, err := http.NewRequest("POST", url, bytes.NewReader(b))
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(b))
if err != nil {
return nil, nil, err
}
@ -271,9 +269,27 @@ func (c *Client) httpClient() *http.Client {
}
// packageVersion is the version of the module that contains this package, for
// sending as part of the User-Agent header. It's set in version_go112.go.
// sending as part of the User-Agent header.
var packageVersion string
func init() {
// Set packageVersion if the binary was built in modules mode and x/crypto
// was not replaced with a different module.
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
for _, m := range info.Deps {
if m.Path != "golang.org/x/crypto" {
continue
}
if m.Replace == nil {
packageVersion = m.Version
}
break
}
}
// userAgent returns the User-Agent header value. It includes the package name,
// the module version (if available), and the c.UserAgent value (if set).
func (c *Client) userAgent() string {

View file

@ -92,7 +92,7 @@ func jwsEncodeJSON(claimset interface{}, key crypto.Signer, kid KeyID, nonce, ur
if err != nil {
return nil, err
}
phead := base64.RawURLEncoding.EncodeToString([]byte(phJSON))
phead := base64.RawURLEncoding.EncodeToString(phJSON)
var payload string
if val, ok := claimset.(string); ok {
payload = val

View file

@ -232,7 +232,7 @@ func (c *Client) AuthorizeOrder(ctx context.Context, id []AuthzID, opt ...OrderO
return responseOrder(res)
}
// GetOrder retrives an order identified by the given URL.
// GetOrder retrieves an order identified by the given URL.
// For orders created with AuthorizeOrder, the url value is Order.URI.
//
// If a caller needs to poll an order until its status is final,
@ -272,7 +272,7 @@ func (c *Client) WaitOrder(ctx context.Context, url string) (*Order, error) {
case err != nil:
// Skip and retry.
case o.Status == StatusInvalid:
return nil, &OrderError{OrderURL: o.URI, Status: o.Status}
return nil, &OrderError{OrderURL: o.URI, Status: o.Status, Problem: o.Error}
case o.Status == StatusReady || o.Status == StatusValid:
return o, nil
}
@ -369,7 +369,7 @@ func (c *Client) CreateOrderCert(ctx context.Context, url string, csr []byte, bu
}
// The only acceptable status post finalize and WaitOrder is "valid".
if o.Status != StatusValid {
return nil, "", &OrderError{OrderURL: o.URI, Status: o.Status}
return nil, "", &OrderError{OrderURL: o.URI, Status: o.Status, Problem: o.Error}
}
crt, err := c.fetchCertRFC(ctx, o.CertURL, bundle)
return crt, o.CertURL, err

View file

@ -7,6 +7,7 @@ package acme
import (
"crypto"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net/http"
@ -55,6 +56,10 @@ var (
// ErrNoAccount indicates that the Client's key has not been registered with the CA.
ErrNoAccount = errors.New("acme: account does not exist")
// errPreAuthorizationNotSupported indicates that the server does not
// support pre-authorization of identifiers.
errPreAuthorizationNotSupported = errors.New("acme: pre-authorization is not supported")
)
// A Subproblem describes an ACME subproblem as reported in an Error.
@ -149,13 +154,16 @@ func (a *AuthorizationError) Error() string {
// OrderError is returned from Client's order related methods.
// It indicates the order is unusable and the clients should start over with
// AuthorizeOrder.
// AuthorizeOrder. A Problem description may be provided with details on
// what caused the order to become unusable.
//
// The clients can still fetch the order object from CA using GetOrder
// to inspect its state.
type OrderError struct {
OrderURL string
Status string
// Problem is the error that occurred while processing the order.
Problem *Error
}
func (oe *OrderError) Error() string {
@ -288,7 +296,7 @@ type Directory struct {
// KeyChangeURL allows to perform account key rollover flow.
KeyChangeURL string
// Term is a URI identifying the current terms of service.
// Terms is a URI identifying the current terms of service.
Terms string
// Website is an HTTP or HTTPS URL locating a website
@ -527,6 +535,16 @@ type Challenge struct {
// when this challenge was used.
// The type of a non-nil value is *Error.
Error error
// Payload is the JSON-formatted payload that the client sends
// to the server to indicate it is ready to respond to the challenge.
// When unset, it defaults to an empty JSON object: {}.
// For most challenges, the client must not set Payload,
// see https://tools.ietf.org/html/rfc8555#section-7.5.1.
// Payload is used only for newer challenges (such as "device-attest-01")
// where the client must send additional data for the server to validate
// the challenge.
Payload json.RawMessage
}
// wireChallenge is ACME JSON challenge representation.
@ -604,7 +622,7 @@ func (*certOptKey) privateCertOpt() {}
//
// In TLS ChallengeCert methods, the template is also used as parent,
// resulting in a self-signed certificate.
// The DNSNames field of t is always overwritten for tls-sni challenge certs.
// The DNSNames or IPAddresses fields of t are always overwritten for tls-alpn challenge certs.
func WithTemplate(t *x509.Certificate) CertOption {
return (*certOptTemplate)(t)
}

View file

@ -1,27 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.12
package acme
import "runtime/debug"
func init() {
// Set packageVersion if the binary was built in modules mode and x/crypto
// was not replaced with a different module.
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
for _, m := range info.Deps {
if m.Path != "golang.org/x/crypto" {
continue
}
if m.Replace == nil {
packageVersion = m.Version
}
break
}
}

View file

@ -6,7 +6,7 @@
// Argon2 was selected as the winner of the Password Hashing Competition and can
// be used to derive cryptographic keys from passwords.
//
// For a detailed specification of Argon2 see [1].
// For a detailed specification of Argon2 see [argon2-specs.pdf].
//
// If you aren't sure which function you need, use Argon2id (IDKey) and
// the parameter recommendations for your scenario.
@ -17,7 +17,7 @@
// It uses data-independent memory access, which is preferred for password
// hashing and password-based key derivation. Argon2i requires more passes over
// memory than Argon2id to protect from trade-off attacks. The recommended
// parameters (taken from [2]) for non-interactive operations are time=3 and to
// parameters (taken from [RFC 9106 Section 7.3]) for non-interactive operations are time=3 and to
// use the maximum available memory.
//
// # Argon2id
@ -27,11 +27,11 @@
// half of the first iteration over the memory and data-dependent memory access
// for the rest. Argon2id is side-channel resistant and provides better brute-
// force cost savings due to time-memory tradeoffs than Argon2i. The recommended
// parameters for non-interactive operations (taken from [2]) are time=1 and to
// parameters for non-interactive operations (taken from [RFC 9106 Section 7.3]) are time=1 and to
// use the maximum available memory.
//
// [1] https://github.com/P-H-C/phc-winner-argon2/blob/master/argon2-specs.pdf
// [2] https://tools.ietf.org/html/draft-irtf-cfrg-argon2-03#section-9.3
// [argon2-specs.pdf]: https://github.com/P-H-C/phc-winner-argon2/blob/master/argon2-specs.pdf
// [RFC 9106 Section 7.3]: https://www.rfc-editor.org/rfc/rfc9106.html#section-7.3
package argon2
import (
@ -59,7 +59,7 @@ const (
//
// key := argon2.Key([]byte("some password"), salt, 3, 32*1024, 4, 32)
//
// The draft RFC recommends[2] time=3, and memory=32*1024 is a sensible number.
// [RFC 9106 Section 7.3] recommends time=3, and memory=32*1024 as a sensible number.
// If using that amount of memory (32 MB) is not possible in some contexts then
// the time parameter can be increased to compensate.
//
@ -69,6 +69,8 @@ const (
// adjusted to the number of available CPUs. The cost parameters should be
// increased as memory latency and CPU parallelism increases. Remember to get a
// good random salt.
//
// [RFC 9106 Section 7.3]: https://www.rfc-editor.org/rfc/rfc9106.html#section-7.3
func Key(password, salt []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
return deriveKey(argon2i, password, salt, nil, nil, time, memory, threads, keyLen)
}
@ -83,7 +85,7 @@ func Key(password, salt []byte, time, memory uint32, threads uint8, keyLen uint3
//
// key := argon2.IDKey([]byte("some password"), salt, 1, 64*1024, 4, 32)
//
// The draft RFC recommends[2] time=1, and memory=64*1024 is a sensible number.
// [RFC 9106 Section 7.3] recommends time=1, and memory=64*1024 as a sensible number.
// If using that amount of memory (64 MB) is not possible in some contexts then
// the time parameter can be increased to compensate.
//
@ -93,6 +95,8 @@ func Key(password, salt []byte, time, memory uint32, threads uint8, keyLen uint3
// adjusted to the numbers of available CPUs. The cost parameters should be
// increased as memory latency and CPU parallelism increases. Remember to get a
// good random salt.
//
// [RFC 9106 Section 7.3]: https://www.rfc-editor.org/rfc/rfc9106.html#section-7.3
func IDKey(password, salt []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
return deriveKey(argon2id, password, salt, nil, nil, time, memory, threads, keyLen)
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -12,6 +12,8 @@ import (
// XOF defines the interface to hash functions that
// support arbitrary-length output.
//
// New callers should prefer the standard library [hash.XOF].
type XOF interface {
// Write absorbs more data into the hash's state. It panics if called
// after Read.
@ -47,6 +49,8 @@ const maxOutputLength = (1 << 32) * 64
//
// A non-nil key turns the hash into a MAC. The key must between
// zero and 32 bytes long.
//
// The result can be safely interface-upgraded to [hash.XOF].
func NewXOF(size uint32, key []byte) (XOF, error) {
if len(key) > Size {
return nil, errKeySize
@ -93,6 +97,10 @@ func (x *xof) Clone() XOF {
return &clone
}
func (x *xof) BlockSize() int {
return x.d.BlockSize()
}
func (x *xof) Reset() {
x.cfg[0] = byte(Size)
binary.LittleEndian.PutUint32(x.cfg[4:], uint32(Size)) // leaf length

11
vendor/golang.org/x/crypto/blake2b/go125.go generated vendored Normal file
View file

@ -0,0 +1,11 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.25
package blake2b
import "hash"
var _ hash.XOF = (*xof)(nil)

View file

@ -234,7 +234,7 @@ func (b *Builder) AddASN1(tag asn1.Tag, f BuilderContinuation) {
// Identifiers with the low five bits set indicate high-tag-number format
// (two or more octets), which we don't support.
if tag&0x1f == 0x1f {
b.err = fmt.Errorf("cryptobyte: high-tag number identifier octects not supported: 0x%x", tag)
b.err = fmt.Errorf("cryptobyte: high-tag number identifier octets not supported: 0x%x", tag)
return
}
b.AddUint8(uint8(tag))

View file

@ -4,7 +4,7 @@
// Package asn1 contains supporting types for parsing and building ASN.1
// messages with the cryptobyte package.
package asn1 // import "golang.org/x/crypto/cryptobyte/asn1"
package asn1
// Tag represents an ASN.1 identifier octet, consisting of a tag number
// (indicating a type) and class (such as context-specific or constructed).

View file

@ -15,7 +15,7 @@
//
// See the documentation and examples for the Builder and String types to get
// started.
package cryptobyte // import "golang.org/x/crypto/cryptobyte"
package cryptobyte
// String represents a string of bytes. It provides methods for parsing
// fixed-length and length-prefixed values from it.

View file

@ -5,7 +5,7 @@
// Package ocsp parses OCSP responses as specified in RFC 2560. OCSP responses
// are signed messages attesting to the validity of a certificate for a small
// period of time. This is used to manage revocation for X.509 certificates.
package ocsp // import "golang.org/x/crypto/ocsp"
package ocsp
import (
"crypto"

View file

@ -16,7 +16,7 @@ Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To
choose, you can pass the `New` functions from the different SHA packages to
pbkdf2.Key.
*/
package pbkdf2 // import "golang.org/x/crypto/pbkdf2"
package pbkdf2
import (
"crypto/hmac"

View file

@ -1,62 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha3 implements the SHA-3 fixed-output-length hash functions and
// the SHAKE variable-output-length hash functions defined by FIPS-202.
//
// Both types of hash function use the "sponge" construction and the Keccak
// permutation. For a detailed specification see http://keccak.noekeon.org/
//
// # Guidance
//
// If you aren't sure what function you need, use SHAKE256 with at least 64
// bytes of output. The SHAKE instances are faster than the SHA3 instances;
// the latter have to allocate memory to conform to the hash.Hash interface.
//
// If you need a secret-key MAC (message authentication code), prepend the
// secret key to the input, hash with SHAKE256 and read at least 32 bytes of
// output.
//
// # Security strengths
//
// The SHA3-x (x equals 224, 256, 384, or 512) functions have a security
// strength against preimage attacks of x bits. Since they only produce "x"
// bits of output, their collision-resistance is only "x/2" bits.
//
// The SHAKE-256 and -128 functions have a generic security strength of 256 and
// 128 bits against all attacks, provided that at least 2x bits of their output
// is used. Requesting more than 64 or 32 bytes of output, respectively, does
// not increase the collision-resistance of the SHAKE functions.
//
// # The sponge construction
//
// A sponge builds a pseudo-random function from a public pseudo-random
// permutation, by applying the permutation to a state of "rate + capacity"
// bytes, but hiding "capacity" of the bytes.
//
// A sponge starts out with a zero state. To hash an input using a sponge, up
// to "rate" bytes of the input are XORed into the sponge's state. The sponge
// is then "full" and the permutation is applied to "empty" it. This process is
// repeated until all the input has been "absorbed". The input is then padded.
// The digest is "squeezed" from the sponge in the same way, except that output
// is copied out instead of input being XORed in.
//
// A sponge is parameterized by its generic security strength, which is equal
// to half its capacity; capacity + rate is equal to the permutation's width.
// Since the KeccakF-1600 permutation is 1600 bits (200 bytes) wide, this means
// that the security strength of a sponge instance is equal to (1600 - bitrate) / 2.
//
// # Recommendations
//
// The SHAKE functions are recommended for most new uses. They can produce
// output of arbitrary length. SHAKE256, with an output length of at least
// 64 bytes, provides 256-bit security against all attacks. The Keccak team
// recommends it for most applications upgrading from SHA2-512. (NIST chose a
// much stronger, but much slower, sponge instance for SHA3-512.)
//
// The SHA-3 functions are "drop-in" replacements for the SHA-2 functions.
// They produce output of the same length, with the same security strengths
// against all attacks. This means, in particular, that SHA3-256 only has
// 128-bit collision resistance, because its output length is 32 bytes.
package sha3 // import "golang.org/x/crypto/sha3"

View file

@ -2,96 +2,94 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha3 implements the SHA-3 hash algorithms and the SHAKE extendable
// output functions defined in FIPS 202.
//
// Most of this package is a wrapper around the crypto/sha3 package in the
// standard library. The only exception is the legacy Keccak hash functions.
package sha3
// This file provides functions for creating instances of the SHA-3
// and SHAKE hash functions, as well as utility functions for hashing
// bytes.
import (
"crypto/sha3"
"hash"
)
// New224 creates a new SHA3-224 hash.
// Its generic security strength is 224 bits against preimage attacks,
// and 112 bits against collision attacks.
//
// It is a wrapper for the [sha3.New224] function in the standard library.
//
//go:fix inline
func New224() hash.Hash {
if h := new224Asm(); h != nil {
return h
}
return &state{rate: 144, outputLen: 28, dsbyte: 0x06}
return sha3.New224()
}
// New256 creates a new SHA3-256 hash.
// Its generic security strength is 256 bits against preimage attacks,
// and 128 bits against collision attacks.
//
// It is a wrapper for the [sha3.New256] function in the standard library.
//
//go:fix inline
func New256() hash.Hash {
if h := new256Asm(); h != nil {
return h
}
return &state{rate: 136, outputLen: 32, dsbyte: 0x06}
return sha3.New256()
}
// New384 creates a new SHA3-384 hash.
// Its generic security strength is 384 bits against preimage attacks,
// and 192 bits against collision attacks.
//
// It is a wrapper for the [sha3.New384] function in the standard library.
//
//go:fix inline
func New384() hash.Hash {
if h := new384Asm(); h != nil {
return h
}
return &state{rate: 104, outputLen: 48, dsbyte: 0x06}
return sha3.New384()
}
// New512 creates a new SHA3-512 hash.
// Its generic security strength is 512 bits against preimage attacks,
// and 256 bits against collision attacks.
//
// It is a wrapper for the [sha3.New512] function in the standard library.
//
//go:fix inline
func New512() hash.Hash {
if h := new512Asm(); h != nil {
return h
}
return &state{rate: 72, outputLen: 64, dsbyte: 0x06}
return sha3.New512()
}
// NewLegacyKeccak256 creates a new Keccak-256 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New256 instead.
func NewLegacyKeccak256() hash.Hash { return &state{rate: 136, outputLen: 32, dsbyte: 0x01} }
// NewLegacyKeccak512 creates a new Keccak-512 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New512 instead.
func NewLegacyKeccak512() hash.Hash { return &state{rate: 72, outputLen: 64, dsbyte: 0x01} }
// Sum224 returns the SHA3-224 digest of the data.
func Sum224(data []byte) (digest [28]byte) {
h := New224()
h.Write(data)
h.Sum(digest[:0])
return
//
// It is a wrapper for the [sha3.Sum224] function in the standard library.
//
//go:fix inline
func Sum224(data []byte) [28]byte {
return sha3.Sum224(data)
}
// Sum256 returns the SHA3-256 digest of the data.
func Sum256(data []byte) (digest [32]byte) {
h := New256()
h.Write(data)
h.Sum(digest[:0])
return
//
// It is a wrapper for the [sha3.Sum256] function in the standard library.
//
//go:fix inline
func Sum256(data []byte) [32]byte {
return sha3.Sum256(data)
}
// Sum384 returns the SHA3-384 digest of the data.
func Sum384(data []byte) (digest [48]byte) {
h := New384()
h.Write(data)
h.Sum(digest[:0])
return
//
// It is a wrapper for the [sha3.Sum384] function in the standard library.
//
//go:fix inline
func Sum384(data []byte) [48]byte {
return sha3.Sum384(data)
}
// Sum512 returns the SHA3-512 digest of the data.
func Sum512(data []byte) (digest [64]byte) {
h := New512()
h.Write(data)
h.Sum(digest[:0])
return
//
// It is a wrapper for the [sha3.Sum512] function in the standard library.
//
//go:fix inline
func Sum512(data []byte) [64]byte {
return sha3.Sum512(data)
}

View file

@ -1,27 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !gc || purego || !s390x
package sha3
import (
"hash"
)
// new224Asm returns an assembly implementation of SHA3-224 if available,
// otherwise it returns nil.
func new224Asm() hash.Hash { return nil }
// new256Asm returns an assembly implementation of SHA3-256 if available,
// otherwise it returns nil.
func new256Asm() hash.Hash { return nil }
// new384Asm returns an assembly implementation of SHA3-384 if available,
// otherwise it returns nil.
func new384Asm() hash.Hash { return nil }
// new512Asm returns an assembly implementation of SHA3-512 if available,
// otherwise it returns nil.
func new512Asm() hash.Hash { return nil }

View file

@ -1,13 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 && !purego && gc
package sha3
// This function is implemented in keccakf_amd64.s.
//go:noescape
func keccakF1600(a *[25]uint64)

View file

@ -1,390 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 && !purego && gc
// This code was translated into a form compatible with 6a from the public
// domain sources at https://github.com/gvanas/KeccakCodePackage
// Offsets in state
#define _ba (0*8)
#define _be (1*8)
#define _bi (2*8)
#define _bo (3*8)
#define _bu (4*8)
#define _ga (5*8)
#define _ge (6*8)
#define _gi (7*8)
#define _go (8*8)
#define _gu (9*8)
#define _ka (10*8)
#define _ke (11*8)
#define _ki (12*8)
#define _ko (13*8)
#define _ku (14*8)
#define _ma (15*8)
#define _me (16*8)
#define _mi (17*8)
#define _mo (18*8)
#define _mu (19*8)
#define _sa (20*8)
#define _se (21*8)
#define _si (22*8)
#define _so (23*8)
#define _su (24*8)
// Temporary registers
#define rT1 AX
// Round vars
#define rpState DI
#define rpStack SP
#define rDa BX
#define rDe CX
#define rDi DX
#define rDo R8
#define rDu R9
#define rBa R10
#define rBe R11
#define rBi R12
#define rBo R13
#define rBu R14
#define rCa SI
#define rCe BP
#define rCi rBi
#define rCo rBo
#define rCu R15
#define MOVQ_RBI_RCE MOVQ rBi, rCe
#define XORQ_RT1_RCA XORQ rT1, rCa
#define XORQ_RT1_RCE XORQ rT1, rCe
#define XORQ_RBA_RCU XORQ rBa, rCu
#define XORQ_RBE_RCU XORQ rBe, rCu
#define XORQ_RDU_RCU XORQ rDu, rCu
#define XORQ_RDA_RCA XORQ rDa, rCa
#define XORQ_RDE_RCE XORQ rDe, rCe
#define mKeccakRound(iState, oState, rc, B_RBI_RCE, G_RT1_RCA, G_RT1_RCE, G_RBA_RCU, K_RT1_RCA, K_RT1_RCE, K_RBA_RCU, M_RT1_RCA, M_RT1_RCE, M_RBE_RCU, S_RDU_RCU, S_RDA_RCA, S_RDE_RCE) \
/* Prepare round */ \
MOVQ rCe, rDa; \
ROLQ $1, rDa; \
\
MOVQ _bi(iState), rCi; \
XORQ _gi(iState), rDi; \
XORQ rCu, rDa; \
XORQ _ki(iState), rCi; \
XORQ _mi(iState), rDi; \
XORQ rDi, rCi; \
\
MOVQ rCi, rDe; \
ROLQ $1, rDe; \
\
MOVQ _bo(iState), rCo; \
XORQ _go(iState), rDo; \
XORQ rCa, rDe; \
XORQ _ko(iState), rCo; \
XORQ _mo(iState), rDo; \
XORQ rDo, rCo; \
\
MOVQ rCo, rDi; \
ROLQ $1, rDi; \
\
MOVQ rCu, rDo; \
XORQ rCe, rDi; \
ROLQ $1, rDo; \
\
MOVQ rCa, rDu; \
XORQ rCi, rDo; \
ROLQ $1, rDu; \
\
/* Result b */ \
MOVQ _ba(iState), rBa; \
MOVQ _ge(iState), rBe; \
XORQ rCo, rDu; \
MOVQ _ki(iState), rBi; \
MOVQ _mo(iState), rBo; \
MOVQ _su(iState), rBu; \
XORQ rDe, rBe; \
ROLQ $44, rBe; \
XORQ rDi, rBi; \
XORQ rDa, rBa; \
ROLQ $43, rBi; \
\
MOVQ rBe, rCa; \
MOVQ rc, rT1; \
ORQ rBi, rCa; \
XORQ rBa, rT1; \
XORQ rT1, rCa; \
MOVQ rCa, _ba(oState); \
\
XORQ rDu, rBu; \
ROLQ $14, rBu; \
MOVQ rBa, rCu; \
ANDQ rBe, rCu; \
XORQ rBu, rCu; \
MOVQ rCu, _bu(oState); \
\
XORQ rDo, rBo; \
ROLQ $21, rBo; \
MOVQ rBo, rT1; \
ANDQ rBu, rT1; \
XORQ rBi, rT1; \
MOVQ rT1, _bi(oState); \
\
NOTQ rBi; \
ORQ rBa, rBu; \
ORQ rBo, rBi; \
XORQ rBo, rBu; \
XORQ rBe, rBi; \
MOVQ rBu, _bo(oState); \
MOVQ rBi, _be(oState); \
B_RBI_RCE; \
\
/* Result g */ \
MOVQ _gu(iState), rBe; \
XORQ rDu, rBe; \
MOVQ _ka(iState), rBi; \
ROLQ $20, rBe; \
XORQ rDa, rBi; \
ROLQ $3, rBi; \
MOVQ _bo(iState), rBa; \
MOVQ rBe, rT1; \
ORQ rBi, rT1; \
XORQ rDo, rBa; \
MOVQ _me(iState), rBo; \
MOVQ _si(iState), rBu; \
ROLQ $28, rBa; \
XORQ rBa, rT1; \
MOVQ rT1, _ga(oState); \
G_RT1_RCA; \
\
XORQ rDe, rBo; \
ROLQ $45, rBo; \
MOVQ rBi, rT1; \
ANDQ rBo, rT1; \
XORQ rBe, rT1; \
MOVQ rT1, _ge(oState); \
G_RT1_RCE; \
\
XORQ rDi, rBu; \
ROLQ $61, rBu; \
MOVQ rBu, rT1; \
ORQ rBa, rT1; \
XORQ rBo, rT1; \
MOVQ rT1, _go(oState); \
\
ANDQ rBe, rBa; \
XORQ rBu, rBa; \
MOVQ rBa, _gu(oState); \
NOTQ rBu; \
G_RBA_RCU; \
\
ORQ rBu, rBo; \
XORQ rBi, rBo; \
MOVQ rBo, _gi(oState); \
\
/* Result k */ \
MOVQ _be(iState), rBa; \
MOVQ _gi(iState), rBe; \
MOVQ _ko(iState), rBi; \
MOVQ _mu(iState), rBo; \
MOVQ _sa(iState), rBu; \
XORQ rDi, rBe; \
ROLQ $6, rBe; \
XORQ rDo, rBi; \
ROLQ $25, rBi; \
MOVQ rBe, rT1; \
ORQ rBi, rT1; \
XORQ rDe, rBa; \
ROLQ $1, rBa; \
XORQ rBa, rT1; \
MOVQ rT1, _ka(oState); \
K_RT1_RCA; \
\
XORQ rDu, rBo; \
ROLQ $8, rBo; \
MOVQ rBi, rT1; \
ANDQ rBo, rT1; \
XORQ rBe, rT1; \
MOVQ rT1, _ke(oState); \
K_RT1_RCE; \
\
XORQ rDa, rBu; \
ROLQ $18, rBu; \
NOTQ rBo; \
MOVQ rBo, rT1; \
ANDQ rBu, rT1; \
XORQ rBi, rT1; \
MOVQ rT1, _ki(oState); \
\
MOVQ rBu, rT1; \
ORQ rBa, rT1; \
XORQ rBo, rT1; \
MOVQ rT1, _ko(oState); \
\
ANDQ rBe, rBa; \
XORQ rBu, rBa; \
MOVQ rBa, _ku(oState); \
K_RBA_RCU; \
\
/* Result m */ \
MOVQ _ga(iState), rBe; \
XORQ rDa, rBe; \
MOVQ _ke(iState), rBi; \
ROLQ $36, rBe; \
XORQ rDe, rBi; \
MOVQ _bu(iState), rBa; \
ROLQ $10, rBi; \
MOVQ rBe, rT1; \
MOVQ _mi(iState), rBo; \
ANDQ rBi, rT1; \
XORQ rDu, rBa; \
MOVQ _so(iState), rBu; \
ROLQ $27, rBa; \
XORQ rBa, rT1; \
MOVQ rT1, _ma(oState); \
M_RT1_RCA; \
\
XORQ rDi, rBo; \
ROLQ $15, rBo; \
MOVQ rBi, rT1; \
ORQ rBo, rT1; \
XORQ rBe, rT1; \
MOVQ rT1, _me(oState); \
M_RT1_RCE; \
\
XORQ rDo, rBu; \
ROLQ $56, rBu; \
NOTQ rBo; \
MOVQ rBo, rT1; \
ORQ rBu, rT1; \
XORQ rBi, rT1; \
MOVQ rT1, _mi(oState); \
\
ORQ rBa, rBe; \
XORQ rBu, rBe; \
MOVQ rBe, _mu(oState); \
\
ANDQ rBa, rBu; \
XORQ rBo, rBu; \
MOVQ rBu, _mo(oState); \
M_RBE_RCU; \
\
/* Result s */ \
MOVQ _bi(iState), rBa; \
MOVQ _go(iState), rBe; \
MOVQ _ku(iState), rBi; \
XORQ rDi, rBa; \
MOVQ _ma(iState), rBo; \
ROLQ $62, rBa; \
XORQ rDo, rBe; \
MOVQ _se(iState), rBu; \
ROLQ $55, rBe; \
\
XORQ rDu, rBi; \
MOVQ rBa, rDu; \
XORQ rDe, rBu; \
ROLQ $2, rBu; \
ANDQ rBe, rDu; \
XORQ rBu, rDu; \
MOVQ rDu, _su(oState); \
\
ROLQ $39, rBi; \
S_RDU_RCU; \
NOTQ rBe; \
XORQ rDa, rBo; \
MOVQ rBe, rDa; \
ANDQ rBi, rDa; \
XORQ rBa, rDa; \
MOVQ rDa, _sa(oState); \
S_RDA_RCA; \
\
ROLQ $41, rBo; \
MOVQ rBi, rDe; \
ORQ rBo, rDe; \
XORQ rBe, rDe; \
MOVQ rDe, _se(oState); \
S_RDE_RCE; \
\
MOVQ rBo, rDi; \
MOVQ rBu, rDo; \
ANDQ rBu, rDi; \
ORQ rBa, rDo; \
XORQ rBi, rDi; \
XORQ rBo, rDo; \
MOVQ rDi, _si(oState); \
MOVQ rDo, _so(oState) \
// func keccakF1600(a *[25]uint64)
TEXT ·keccakF1600(SB), 0, $200-8
MOVQ a+0(FP), rpState
// Convert the user state into an internal state
NOTQ _be(rpState)
NOTQ _bi(rpState)
NOTQ _go(rpState)
NOTQ _ki(rpState)
NOTQ _mi(rpState)
NOTQ _sa(rpState)
// Execute the KeccakF permutation
MOVQ _ba(rpState), rCa
MOVQ _be(rpState), rCe
MOVQ _bu(rpState), rCu
XORQ _ga(rpState), rCa
XORQ _ge(rpState), rCe
XORQ _gu(rpState), rCu
XORQ _ka(rpState), rCa
XORQ _ke(rpState), rCe
XORQ _ku(rpState), rCu
XORQ _ma(rpState), rCa
XORQ _me(rpState), rCe
XORQ _mu(rpState), rCu
XORQ _sa(rpState), rCa
XORQ _se(rpState), rCe
MOVQ _si(rpState), rDi
MOVQ _so(rpState), rDo
XORQ _su(rpState), rCu
mKeccakRound(rpState, rpStack, $0x0000000000000001, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x0000000000008082, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x800000000000808a, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000080008000, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x000000000000808b, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x0000000080000001, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x8000000080008081, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000000008009, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x000000000000008a, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x0000000000000088, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x0000000080008009, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x000000008000000a, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x000000008000808b, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x800000000000008b, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x8000000000008089, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000000008003, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x8000000000008002, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000000000080, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x000000000000800a, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x800000008000000a, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x8000000080008081, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000000008080, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x0000000080000001, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000080008008, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP)
// Revert the internal state to the user state
NOTQ _be(rpState)
NOTQ _bi(rpState)
NOTQ _go(rpState)
NOTQ _ki(rpState)
NOTQ _mi(rpState)
NOTQ _sa(rpState)
RET

263
vendor/golang.org/x/crypto/sha3/legacy_hash.go generated vendored Normal file
View file

@ -0,0 +1,263 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha3
// This implementation is only used for NewLegacyKeccak256 and
// NewLegacyKeccak512, which are not implemented by crypto/sha3.
// All other functions in this package are wrappers around crypto/sha3.
import (
"crypto/subtle"
"encoding/binary"
"errors"
"hash"
"unsafe"
"golang.org/x/sys/cpu"
)
const (
dsbyteKeccak = 0b00000001
// rateK[c] is the rate in bytes for Keccak[c] where c is the capacity in
// bits. Given the sponge size is 1600 bits, the rate is 1600 - c bits.
rateK256 = (1600 - 256) / 8
rateK512 = (1600 - 512) / 8
rateK1024 = (1600 - 1024) / 8
)
// NewLegacyKeccak256 creates a new Keccak-256 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New256 instead.
func NewLegacyKeccak256() hash.Hash {
return &state{rate: rateK512, outputLen: 32, dsbyte: dsbyteKeccak}
}
// NewLegacyKeccak512 creates a new Keccak-512 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New512 instead.
func NewLegacyKeccak512() hash.Hash {
return &state{rate: rateK1024, outputLen: 64, dsbyte: dsbyteKeccak}
}
// spongeDirection indicates the direction bytes are flowing through the sponge.
type spongeDirection int
const (
// spongeAbsorbing indicates that the sponge is absorbing input.
spongeAbsorbing spongeDirection = iota
// spongeSqueezing indicates that the sponge is being squeezed.
spongeSqueezing
)
type state struct {
a [1600 / 8]byte // main state of the hash
// a[n:rate] is the buffer. If absorbing, it's the remaining space to XOR
// into before running the permutation. If squeezing, it's the remaining
// output to produce before running the permutation.
n, rate int
// dsbyte contains the "domain separation" bits and the first bit of
// the padding. Sections 6.1 and 6.2 of [1] separate the outputs of the
// SHA-3 and SHAKE functions by appending bitstrings to the message.
// Using a little-endian bit-ordering convention, these are "01" for SHA-3
// and "1111" for SHAKE, or 00000010b and 00001111b, respectively. Then the
// padding rule from section 5.1 is applied to pad the message to a multiple
// of the rate, which involves adding a "1" bit, zero or more "0" bits, and
// a final "1" bit. We merge the first "1" bit from the padding into dsbyte,
// giving 00000110b (0x06) and 00011111b (0x1f).
// [1] http://csrc.nist.gov/publications/drafts/fips-202/fips_202_draft.pdf
// "Draft FIPS 202: SHA-3 Standard: Permutation-Based Hash and
// Extendable-Output Functions (May 2014)"
dsbyte byte
outputLen int // the default output size in bytes
state spongeDirection // whether the sponge is absorbing or squeezing
}
// BlockSize returns the rate of sponge underlying this hash function.
func (d *state) BlockSize() int { return d.rate }
// Size returns the output size of the hash function in bytes.
func (d *state) Size() int { return d.outputLen }
// Reset clears the internal state by zeroing the sponge state and
// the buffer indexes, and setting Sponge.state to absorbing.
func (d *state) Reset() {
// Zero the permutation's state.
for i := range d.a {
d.a[i] = 0
}
d.state = spongeAbsorbing
d.n = 0
}
func (d *state) clone() *state {
ret := *d
return &ret
}
// permute applies the KeccakF-1600 permutation.
func (d *state) permute() {
var a *[25]uint64
if cpu.IsBigEndian {
a = new([25]uint64)
for i := range a {
a[i] = binary.LittleEndian.Uint64(d.a[i*8:])
}
} else {
a = (*[25]uint64)(unsafe.Pointer(&d.a))
}
keccakF1600(a)
d.n = 0
if cpu.IsBigEndian {
for i := range a {
binary.LittleEndian.PutUint64(d.a[i*8:], a[i])
}
}
}
// pads appends the domain separation bits in dsbyte, applies
// the multi-bitrate 10..1 padding rule, and permutes the state.
func (d *state) padAndPermute() {
// Pad with this instance's domain-separator bits. We know that there's
// at least one byte of space in the sponge because, if it were full,
// permute would have been called to empty it. dsbyte also contains the
// first one bit for the padding. See the comment in the state struct.
d.a[d.n] ^= d.dsbyte
// This adds the final one bit for the padding. Because of the way that
// bits are numbered from the LSB upwards, the final bit is the MSB of
// the last byte.
d.a[d.rate-1] ^= 0x80
// Apply the permutation
d.permute()
d.state = spongeSqueezing
}
// Write absorbs more data into the hash's state. It panics if any
// output has already been read.
func (d *state) Write(p []byte) (n int, err error) {
if d.state != spongeAbsorbing {
panic("sha3: Write after Read")
}
n = len(p)
for len(p) > 0 {
x := subtle.XORBytes(d.a[d.n:d.rate], d.a[d.n:d.rate], p)
d.n += x
p = p[x:]
// If the sponge is full, apply the permutation.
if d.n == d.rate {
d.permute()
}
}
return
}
// Read squeezes an arbitrary number of bytes from the sponge.
func (d *state) Read(out []byte) (n int, err error) {
// If we're still absorbing, pad and apply the permutation.
if d.state == spongeAbsorbing {
d.padAndPermute()
}
n = len(out)
// Now, do the squeezing.
for len(out) > 0 {
// Apply the permutation if we've squeezed the sponge dry.
if d.n == d.rate {
d.permute()
}
x := copy(out, d.a[d.n:d.rate])
d.n += x
out = out[x:]
}
return
}
// Sum applies padding to the hash state and then squeezes out the desired
// number of output bytes. It panics if any output has already been read.
func (d *state) Sum(in []byte) []byte {
if d.state != spongeAbsorbing {
panic("sha3: Sum after Read")
}
// Make a copy of the original hash so that caller can keep writing
// and summing.
dup := d.clone()
hash := make([]byte, dup.outputLen, 64) // explicit cap to allow stack allocation
dup.Read(hash)
return append(in, hash...)
}
const (
magicKeccak = "sha\x0b"
// magic || rate || main state || n || sponge direction
marshaledSize = len(magicKeccak) + 1 + 200 + 1 + 1
)
func (d *state) MarshalBinary() ([]byte, error) {
return d.AppendBinary(make([]byte, 0, marshaledSize))
}
func (d *state) AppendBinary(b []byte) ([]byte, error) {
switch d.dsbyte {
case dsbyteKeccak:
b = append(b, magicKeccak...)
default:
panic("unknown dsbyte")
}
// rate is at most 168, and n is at most rate.
b = append(b, byte(d.rate))
b = append(b, d.a[:]...)
b = append(b, byte(d.n), byte(d.state))
return b, nil
}
func (d *state) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize {
return errors.New("sha3: invalid hash state")
}
magic := string(b[:len(magicKeccak)])
b = b[len(magicKeccak):]
switch {
case magic == magicKeccak && d.dsbyte == dsbyteKeccak:
default:
return errors.New("sha3: invalid hash state identifier")
}
rate := int(b[0])
b = b[1:]
if rate != d.rate {
return errors.New("sha3: invalid hash state function")
}
copy(d.a[:], b)
b = b[len(d.a):]
n, state := int(b[0]), spongeDirection(b[1])
if n > d.rate {
return errors.New("sha3: invalid hash state")
}
d.n = n
if state != spongeAbsorbing && state != spongeSqueezing {
return errors.New("sha3: invalid hash state")
}
d.state = state
return nil
}

View file

@ -2,10 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || purego || !gc
package sha3
// This implementation is only used for NewLegacyKeccak256 and
// NewLegacyKeccak512, which are not implemented by crypto/sha3.
// All other functions in this package are wrappers around crypto/sha3.
import "math/bits"
// rc stores the round constants for use in the ι step.

View file

@ -1,18 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.4
package sha3
import (
"crypto"
)
func init() {
crypto.RegisterHash(crypto.SHA3_224, New224)
crypto.RegisterHash(crypto.SHA3_256, New256)
crypto.RegisterHash(crypto.SHA3_384, New384)
crypto.RegisterHash(crypto.SHA3_512, New512)
}

View file

@ -1,197 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha3
// spongeDirection indicates the direction bytes are flowing through the sponge.
type spongeDirection int
const (
// spongeAbsorbing indicates that the sponge is absorbing input.
spongeAbsorbing spongeDirection = iota
// spongeSqueezing indicates that the sponge is being squeezed.
spongeSqueezing
)
const (
// maxRate is the maximum size of the internal buffer. SHAKE-256
// currently needs the largest buffer.
maxRate = 168
)
type state struct {
// Generic sponge components.
a [25]uint64 // main state of the hash
buf []byte // points into storage
rate int // the number of bytes of state to use
// dsbyte contains the "domain separation" bits and the first bit of
// the padding. Sections 6.1 and 6.2 of [1] separate the outputs of the
// SHA-3 and SHAKE functions by appending bitstrings to the message.
// Using a little-endian bit-ordering convention, these are "01" for SHA-3
// and "1111" for SHAKE, or 00000010b and 00001111b, respectively. Then the
// padding rule from section 5.1 is applied to pad the message to a multiple
// of the rate, which involves adding a "1" bit, zero or more "0" bits, and
// a final "1" bit. We merge the first "1" bit from the padding into dsbyte,
// giving 00000110b (0x06) and 00011111b (0x1f).
// [1] http://csrc.nist.gov/publications/drafts/fips-202/fips_202_draft.pdf
// "Draft FIPS 202: SHA-3 Standard: Permutation-Based Hash and
// Extendable-Output Functions (May 2014)"
dsbyte byte
storage storageBuf
// Specific to SHA-3 and SHAKE.
outputLen int // the default output size in bytes
state spongeDirection // whether the sponge is absorbing or squeezing
}
// BlockSize returns the rate of sponge underlying this hash function.
func (d *state) BlockSize() int { return d.rate }
// Size returns the output size of the hash function in bytes.
func (d *state) Size() int { return d.outputLen }
// Reset clears the internal state by zeroing the sponge state and
// the byte buffer, and setting Sponge.state to absorbing.
func (d *state) Reset() {
// Zero the permutation's state.
for i := range d.a {
d.a[i] = 0
}
d.state = spongeAbsorbing
d.buf = d.storage.asBytes()[:0]
}
func (d *state) clone() *state {
ret := *d
if ret.state == spongeAbsorbing {
ret.buf = ret.storage.asBytes()[:len(ret.buf)]
} else {
ret.buf = ret.storage.asBytes()[d.rate-cap(d.buf) : d.rate]
}
return &ret
}
// permute applies the KeccakF-1600 permutation. It handles
// any input-output buffering.
func (d *state) permute() {
switch d.state {
case spongeAbsorbing:
// If we're absorbing, we need to xor the input into the state
// before applying the permutation.
xorIn(d, d.buf)
d.buf = d.storage.asBytes()[:0]
keccakF1600(&d.a)
case spongeSqueezing:
// If we're squeezing, we need to apply the permutation before
// copying more output.
keccakF1600(&d.a)
d.buf = d.storage.asBytes()[:d.rate]
copyOut(d, d.buf)
}
}
// pads appends the domain separation bits in dsbyte, applies
// the multi-bitrate 10..1 padding rule, and permutes the state.
func (d *state) padAndPermute(dsbyte byte) {
if d.buf == nil {
d.buf = d.storage.asBytes()[:0]
}
// Pad with this instance's domain-separator bits. We know that there's
// at least one byte of space in d.buf because, if it were full,
// permute would have been called to empty it. dsbyte also contains the
// first one bit for the padding. See the comment in the state struct.
d.buf = append(d.buf, dsbyte)
zerosStart := len(d.buf)
d.buf = d.storage.asBytes()[:d.rate]
for i := zerosStart; i < d.rate; i++ {
d.buf[i] = 0
}
// This adds the final one bit for the padding. Because of the way that
// bits are numbered from the LSB upwards, the final bit is the MSB of
// the last byte.
d.buf[d.rate-1] ^= 0x80
// Apply the permutation
d.permute()
d.state = spongeSqueezing
d.buf = d.storage.asBytes()[:d.rate]
copyOut(d, d.buf)
}
// Write absorbs more data into the hash's state. It panics if any
// output has already been read.
func (d *state) Write(p []byte) (written int, err error) {
if d.state != spongeAbsorbing {
panic("sha3: Write after Read")
}
if d.buf == nil {
d.buf = d.storage.asBytes()[:0]
}
written = len(p)
for len(p) > 0 {
if len(d.buf) == 0 && len(p) >= d.rate {
// The fast path; absorb a full "rate" bytes of input and apply the permutation.
xorIn(d, p[:d.rate])
p = p[d.rate:]
keccakF1600(&d.a)
} else {
// The slow path; buffer the input until we can fill the sponge, and then xor it in.
todo := d.rate - len(d.buf)
if todo > len(p) {
todo = len(p)
}
d.buf = append(d.buf, p[:todo]...)
p = p[todo:]
// If the sponge is full, apply the permutation.
if len(d.buf) == d.rate {
d.permute()
}
}
}
return
}
// Read squeezes an arbitrary number of bytes from the sponge.
func (d *state) Read(out []byte) (n int, err error) {
// If we're still absorbing, pad and apply the permutation.
if d.state == spongeAbsorbing {
d.padAndPermute(d.dsbyte)
}
n = len(out)
// Now, do the squeezing.
for len(out) > 0 {
n := copy(out, d.buf)
d.buf = d.buf[n:]
out = out[n:]
// Apply the permutation if we've squeezed the sponge dry.
if len(d.buf) == 0 {
d.permute()
}
}
return
}
// Sum applies padding to the hash state and then squeezes out the desired
// number of output bytes. It panics if any output has already been read.
func (d *state) Sum(in []byte) []byte {
if d.state != spongeAbsorbing {
panic("sha3: Sum after Read")
}
// Make a copy of the original hash so that caller can keep writing
// and summing.
dup := d.clone()
hash := make([]byte, dup.outputLen, 64) // explicit cap to allow stack allocation
dup.Read(hash)
return append(in, hash...)
}

View file

@ -1,303 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
package sha3
// This file contains code for using the 'compute intermediate
// message digest' (KIMD) and 'compute last message digest' (KLMD)
// instructions to compute SHA-3 and SHAKE hashes on IBM Z.
import (
"hash"
"golang.org/x/sys/cpu"
)
// codes represent 7-bit KIMD/KLMD function codes as defined in
// the Principles of Operation.
type code uint64
const (
// function codes for KIMD/KLMD
sha3_224 code = 32
sha3_256 = 33
sha3_384 = 34
sha3_512 = 35
shake_128 = 36
shake_256 = 37
nopad = 0x100
)
// kimd is a wrapper for the 'compute intermediate message digest' instruction.
// src must be a multiple of the rate for the given function code.
//
//go:noescape
func kimd(function code, chain *[200]byte, src []byte)
// klmd is a wrapper for the 'compute last message digest' instruction.
// src padding is handled by the instruction.
//
//go:noescape
func klmd(function code, chain *[200]byte, dst, src []byte)
type asmState struct {
a [200]byte // 1600 bit state
buf []byte // care must be taken to ensure cap(buf) is a multiple of rate
rate int // equivalent to block size
storage [3072]byte // underlying storage for buf
outputLen int // output length for full security
function code // KIMD/KLMD function code
state spongeDirection // whether the sponge is absorbing or squeezing
}
func newAsmState(function code) *asmState {
var s asmState
s.function = function
switch function {
case sha3_224:
s.rate = 144
s.outputLen = 28
case sha3_256:
s.rate = 136
s.outputLen = 32
case sha3_384:
s.rate = 104
s.outputLen = 48
case sha3_512:
s.rate = 72
s.outputLen = 64
case shake_128:
s.rate = 168
s.outputLen = 32
case shake_256:
s.rate = 136
s.outputLen = 64
default:
panic("sha3: unrecognized function code")
}
// limit s.buf size to a multiple of s.rate
s.resetBuf()
return &s
}
func (s *asmState) clone() *asmState {
c := *s
c.buf = c.storage[:len(s.buf):cap(s.buf)]
return &c
}
// copyIntoBuf copies b into buf. It will panic if there is not enough space to
// store all of b.
func (s *asmState) copyIntoBuf(b []byte) {
bufLen := len(s.buf)
s.buf = s.buf[:len(s.buf)+len(b)]
copy(s.buf[bufLen:], b)
}
// resetBuf points buf at storage, sets the length to 0 and sets cap to be a
// multiple of the rate.
func (s *asmState) resetBuf() {
max := (cap(s.storage) / s.rate) * s.rate
s.buf = s.storage[:0:max]
}
// Write (via the embedded io.Writer interface) adds more data to the running hash.
// It never returns an error.
func (s *asmState) Write(b []byte) (int, error) {
if s.state != spongeAbsorbing {
panic("sha3: Write after Read")
}
length := len(b)
for len(b) > 0 {
if len(s.buf) == 0 && len(b) >= cap(s.buf) {
// Hash the data directly and push any remaining bytes
// into the buffer.
remainder := len(b) % s.rate
kimd(s.function, &s.a, b[:len(b)-remainder])
if remainder != 0 {
s.copyIntoBuf(b[len(b)-remainder:])
}
return length, nil
}
if len(s.buf) == cap(s.buf) {
// flush the buffer
kimd(s.function, &s.a, s.buf)
s.buf = s.buf[:0]
}
// copy as much as we can into the buffer
n := len(b)
if len(b) > cap(s.buf)-len(s.buf) {
n = cap(s.buf) - len(s.buf)
}
s.copyIntoBuf(b[:n])
b = b[n:]
}
return length, nil
}
// Read squeezes an arbitrary number of bytes from the sponge.
func (s *asmState) Read(out []byte) (n int, err error) {
// The 'compute last message digest' instruction only stores the digest
// at the first operand (dst) for SHAKE functions.
if s.function != shake_128 && s.function != shake_256 {
panic("sha3: can only call Read for SHAKE functions")
}
n = len(out)
// need to pad if we were absorbing
if s.state == spongeAbsorbing {
s.state = spongeSqueezing
// write hash directly into out if possible
if len(out)%s.rate == 0 {
klmd(s.function, &s.a, out, s.buf) // len(out) may be 0
s.buf = s.buf[:0]
return
}
// write hash into buffer
max := cap(s.buf)
if max > len(out) {
max = (len(out)/s.rate)*s.rate + s.rate
}
klmd(s.function, &s.a, s.buf[:max], s.buf)
s.buf = s.buf[:max]
}
for len(out) > 0 {
// flush the buffer
if len(s.buf) != 0 {
c := copy(out, s.buf)
out = out[c:]
s.buf = s.buf[c:]
continue
}
// write hash directly into out if possible
if len(out)%s.rate == 0 {
klmd(s.function|nopad, &s.a, out, nil)
return
}
// write hash into buffer
s.resetBuf()
if cap(s.buf) > len(out) {
s.buf = s.buf[:(len(out)/s.rate)*s.rate+s.rate]
}
klmd(s.function|nopad, &s.a, s.buf, nil)
}
return
}
// Sum appends the current hash to b and returns the resulting slice.
// It does not change the underlying hash state.
func (s *asmState) Sum(b []byte) []byte {
if s.state != spongeAbsorbing {
panic("sha3: Sum after Read")
}
// Copy the state to preserve the original.
a := s.a
// Hash the buffer. Note that we don't clear it because we
// aren't updating the state.
switch s.function {
case sha3_224, sha3_256, sha3_384, sha3_512:
klmd(s.function, &a, nil, s.buf)
return append(b, a[:s.outputLen]...)
case shake_128, shake_256:
d := make([]byte, s.outputLen, 64)
klmd(s.function, &a, d, s.buf)
return append(b, d[:s.outputLen]...)
default:
panic("sha3: unknown function")
}
}
// Reset resets the Hash to its initial state.
func (s *asmState) Reset() {
for i := range s.a {
s.a[i] = 0
}
s.resetBuf()
s.state = spongeAbsorbing
}
// Size returns the number of bytes Sum will return.
func (s *asmState) Size() int {
return s.outputLen
}
// BlockSize returns the hash's underlying block size.
// The Write method must be able to accept any amount
// of data, but it may operate more efficiently if all writes
// are a multiple of the block size.
func (s *asmState) BlockSize() int {
return s.rate
}
// Clone returns a copy of the ShakeHash in its current state.
func (s *asmState) Clone() ShakeHash {
return s.clone()
}
// new224Asm returns an assembly implementation of SHA3-224 if available,
// otherwise it returns nil.
func new224Asm() hash.Hash {
if cpu.S390X.HasSHA3 {
return newAsmState(sha3_224)
}
return nil
}
// new256Asm returns an assembly implementation of SHA3-256 if available,
// otherwise it returns nil.
func new256Asm() hash.Hash {
if cpu.S390X.HasSHA3 {
return newAsmState(sha3_256)
}
return nil
}
// new384Asm returns an assembly implementation of SHA3-384 if available,
// otherwise it returns nil.
func new384Asm() hash.Hash {
if cpu.S390X.HasSHA3 {
return newAsmState(sha3_384)
}
return nil
}
// new512Asm returns an assembly implementation of SHA3-512 if available,
// otherwise it returns nil.
func new512Asm() hash.Hash {
if cpu.S390X.HasSHA3 {
return newAsmState(sha3_512)
}
return nil
}
// newShake128Asm returns an assembly implementation of SHAKE-128 if available,
// otherwise it returns nil.
func newShake128Asm() ShakeHash {
if cpu.S390X.HasSHA3 {
return newAsmState(shake_128)
}
return nil
}
// newShake256Asm returns an assembly implementation of SHAKE-256 if available,
// otherwise it returns nil.
func newShake256Asm() ShakeHash {
if cpu.S390X.HasSHA3 {
return newAsmState(shake_256)
}
return nil
}

View file

@ -1,33 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
#include "textflag.h"
// func kimd(function code, chain *[200]byte, src []byte)
TEXT ·kimd(SB), NOFRAME|NOSPLIT, $0-40
MOVD function+0(FP), R0
MOVD chain+8(FP), R1
LMG src+16(FP), R2, R3 // R2=base, R3=len
continue:
WORD $0xB93E0002 // KIMD --, R2
BVS continue // continue if interrupted
MOVD $0, R0 // reset R0 for pre-go1.8 compilers
RET
// func klmd(function code, chain *[200]byte, dst, src []byte)
TEXT ·klmd(SB), NOFRAME|NOSPLIT, $0-64
// TODO: SHAKE support
MOVD function+0(FP), R0
MOVD chain+8(FP), R1
LMG dst+16(FP), R2, R3 // R2=base, R3=len
LMG src+40(FP), R4, R5 // R4=base, R5=len
continue:
WORD $0xB93F0024 // KLMD R2, R4
BVS continue // continue if interrupted
MOVD $0, R0 // reset R0 for pre-go1.8 compilers
RET

View file

@ -4,19 +4,8 @@
package sha3
// This file defines the ShakeHash interface, and provides
// functions for creating SHAKE and cSHAKE instances, as well as utility
// functions for hashing bytes to arbitrary-length output.
//
//
// SHAKE implementation is based on FIPS PUB 202 [1]
// cSHAKE implementations is based on NIST SP 800-185 [2]
//
// [1] https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf
// [2] https://doi.org/10.6028/NIST.SP.800-185
import (
"encoding/binary"
"crypto/sha3"
"hash"
"io"
)
@ -29,7 +18,7 @@ type ShakeHash interface {
hash.Hash
// Read reads more output from the hash; reading affects the hash's
// state. (ShakeHash.Read is thus very different from Hash.Sum)
// state. (ShakeHash.Read is thus very different from Hash.Sum.)
// It never returns an error, but subsequent calls to Write or Sum
// will panic.
io.Reader
@ -38,97 +27,18 @@ type ShakeHash interface {
Clone() ShakeHash
}
// cSHAKE specific context
type cshakeState struct {
*state // SHA-3 state context and Read/Write operations
// initBlock is the cSHAKE specific initialization set of bytes. It is initialized
// by newCShake function and stores concatenation of N followed by S, encoded
// by the method specified in 3.3 of [1].
// It is stored here in order for Reset() to be able to put context into
// initial state.
initBlock []byte
}
// Consts for configuring initial SHA-3 state
const (
dsbyteShake = 0x1f
dsbyteCShake = 0x04
rate128 = 168
rate256 = 136
)
func bytepad(input []byte, w int) []byte {
// leftEncode always returns max 9 bytes
buf := make([]byte, 0, 9+len(input)+w)
buf = append(buf, leftEncode(uint64(w))...)
buf = append(buf, input...)
padlen := w - (len(buf) % w)
return append(buf, make([]byte, padlen)...)
}
func leftEncode(value uint64) []byte {
var b [9]byte
binary.BigEndian.PutUint64(b[1:], value)
// Trim all but last leading zero bytes
i := byte(1)
for i < 8 && b[i] == 0 {
i++
}
// Prepend number of encoded bytes
b[i-1] = 9 - i
return b[i-1:]
}
func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash {
c := cshakeState{state: &state{rate: rate, outputLen: outputLen, dsbyte: dsbyte}}
// leftEncode returns max 9 bytes
c.initBlock = make([]byte, 0, 9*2+len(N)+len(S))
c.initBlock = append(c.initBlock, leftEncode(uint64(len(N)*8))...)
c.initBlock = append(c.initBlock, N...)
c.initBlock = append(c.initBlock, leftEncode(uint64(len(S)*8))...)
c.initBlock = append(c.initBlock, S...)
c.Write(bytepad(c.initBlock, c.rate))
return &c
}
// Reset resets the hash to initial state.
func (c *cshakeState) Reset() {
c.state.Reset()
c.Write(bytepad(c.initBlock, c.rate))
}
// Clone returns copy of a cSHAKE context within its current state.
func (c *cshakeState) Clone() ShakeHash {
b := make([]byte, len(c.initBlock))
copy(b, c.initBlock)
return &cshakeState{state: c.clone(), initBlock: b}
}
// Clone returns copy of SHAKE context within its current state.
func (c *state) Clone() ShakeHash {
return c.clone()
}
// NewShake128 creates a new SHAKE128 variable-output-length ShakeHash.
// Its generic security strength is 128 bits against all attacks if at
// least 32 bytes of its output are used.
func NewShake128() ShakeHash {
if h := newShake128Asm(); h != nil {
return h
}
return &state{rate: rate128, outputLen: 32, dsbyte: dsbyteShake}
return &shakeWrapper{sha3.NewSHAKE128(), 32, false, sha3.NewSHAKE128}
}
// NewShake256 creates a new SHAKE256 variable-output-length ShakeHash.
// Its generic security strength is 256 bits against all attacks if
// at least 64 bytes of its output are used.
func NewShake256() ShakeHash {
if h := newShake256Asm(); h != nil {
return h
}
return &state{rate: rate256, outputLen: 64, dsbyte: dsbyteShake}
return &shakeWrapper{sha3.NewSHAKE256(), 64, false, sha3.NewSHAKE256}
}
// NewCShake128 creates a new instance of cSHAKE128 variable-output-length ShakeHash,
@ -138,10 +48,9 @@ func NewShake256() ShakeHash {
// computations on same input with different S yield unrelated outputs.
// When N and S are both empty, this is equivalent to NewShake128.
func NewCShake128(N, S []byte) ShakeHash {
if len(N) == 0 && len(S) == 0 {
return NewShake128()
}
return newCShake(N, S, rate128, 32, dsbyteCShake)
return &shakeWrapper{sha3.NewCSHAKE128(N, S), 32, false, func() *sha3.SHAKE {
return sha3.NewCSHAKE128(N, S)
}}
}
// NewCShake256 creates a new instance of cSHAKE256 variable-output-length ShakeHash,
@ -151,10 +60,9 @@ func NewCShake128(N, S []byte) ShakeHash {
// computations on same input with different S yield unrelated outputs.
// When N and S are both empty, this is equivalent to NewShake256.
func NewCShake256(N, S []byte) ShakeHash {
if len(N) == 0 && len(S) == 0 {
return NewShake256()
}
return newCShake(N, S, rate256, 64, dsbyteCShake)
return &shakeWrapper{sha3.NewCSHAKE256(N, S), 64, false, func() *sha3.SHAKE {
return sha3.NewCSHAKE256(N, S)
}}
}
// ShakeSum128 writes an arbitrary-length digest of data into hash.
@ -170,3 +78,42 @@ func ShakeSum256(hash, data []byte) {
h.Write(data)
h.Read(hash)
}
// shakeWrapper adds the Size, Sum, and Clone methods to a sha3.SHAKE
// to implement the ShakeHash interface.
type shakeWrapper struct {
*sha3.SHAKE
outputLen int
squeezing bool
newSHAKE func() *sha3.SHAKE
}
func (w *shakeWrapper) Read(p []byte) (n int, err error) {
w.squeezing = true
return w.SHAKE.Read(p)
}
func (w *shakeWrapper) Clone() ShakeHash {
s := w.newSHAKE()
b, err := w.MarshalBinary()
if err != nil {
panic(err) // unreachable
}
if err := s.UnmarshalBinary(b); err != nil {
panic(err) // unreachable
}
return &shakeWrapper{s, w.outputLen, w.squeezing, w.newSHAKE}
}
func (w *shakeWrapper) Size() int { return w.outputLen }
func (w *shakeWrapper) Sum(b []byte) []byte {
if w.squeezing {
panic("sha3: Sum after Read")
}
out := make([]byte, w.outputLen)
// Clone the state so that we don't affect future Write calls.
s := w.Clone()
s.Read(out)
return append(b, out...)
}

View file

@ -1,19 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !gc || purego || !s390x
package sha3
// newShake128Asm returns an assembly implementation of SHAKE-128 if available,
// otherwise it returns nil.
func newShake128Asm() ShakeHash {
return nil
}
// newShake256Asm returns an assembly implementation of SHAKE-256 if available,
// otherwise it returns nil.
func newShake256Asm() ShakeHash {
return nil
}

View file

@ -1,23 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (!amd64 && !386 && !ppc64le) || purego
package sha3
// A storageBuf is an aligned array of maxRate bytes.
type storageBuf [maxRate]byte
func (b *storageBuf) asBytes() *[maxRate]byte {
return (*[maxRate]byte)(b)
}
var (
xorIn = xorInGeneric
copyOut = copyOutGeneric
xorInUnaligned = xorInGeneric
copyOutUnaligned = copyOutGeneric
)
const xorImplementationUnaligned = "generic"

View file

@ -1,28 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha3
import "encoding/binary"
// xorInGeneric xors the bytes in buf into the state; it
// makes no non-portable assumptions about memory layout
// or alignment.
func xorInGeneric(d *state, buf []byte) {
n := len(buf) / 8
for i := 0; i < n; i++ {
a := binary.LittleEndian.Uint64(buf)
d.a[i] ^= a
buf = buf[8:]
}
}
// copyOutGeneric copies uint64s to a byte buffer.
func copyOutGeneric(d *state, b []byte) {
for i := 0; len(b) >= 8; i++ {
binary.LittleEndian.PutUint64(b, d.a[i])
b = b[8:]
}
}

View file

@ -1,66 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (amd64 || 386 || ppc64le) && !purego
package sha3
import "unsafe"
// A storageBuf is an aligned array of maxRate bytes.
type storageBuf [maxRate / 8]uint64
func (b *storageBuf) asBytes() *[maxRate]byte {
return (*[maxRate]byte)(unsafe.Pointer(b))
}
// xorInUnaligned uses unaligned reads and writes to update d.a to contain d.a
// XOR buf.
func xorInUnaligned(d *state, buf []byte) {
n := len(buf)
bw := (*[maxRate / 8]uint64)(unsafe.Pointer(&buf[0]))[: n/8 : n/8]
if n >= 72 {
d.a[0] ^= bw[0]
d.a[1] ^= bw[1]
d.a[2] ^= bw[2]
d.a[3] ^= bw[3]
d.a[4] ^= bw[4]
d.a[5] ^= bw[5]
d.a[6] ^= bw[6]
d.a[7] ^= bw[7]
d.a[8] ^= bw[8]
}
if n >= 104 {
d.a[9] ^= bw[9]
d.a[10] ^= bw[10]
d.a[11] ^= bw[11]
d.a[12] ^= bw[12]
}
if n >= 136 {
d.a[13] ^= bw[13]
d.a[14] ^= bw[14]
d.a[15] ^= bw[15]
d.a[16] ^= bw[16]
}
if n >= 144 {
d.a[17] ^= bw[17]
}
if n >= 168 {
d.a[18] ^= bw[18]
d.a[19] ^= bw[19]
d.a[20] ^= bw[20]
}
}
func copyOutUnaligned(d *state, buf []byte) {
ab := (*[maxRate]uint8)(unsafe.Pointer(&d.a[0]))
copy(buf, ab[:])
}
var (
xorIn = xorInUnaligned
copyOut = copyOutUnaligned
)
const xorImplementationUnaligned = "unaligned"

4
vendor/golang.org/x/mod/LICENSE generated vendored
View file

@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
* Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

View file

@ -96,10 +96,11 @@ package module
// Changes to the semantics in this file require approval from rsc.
import (
"cmp"
"errors"
"fmt"
"path"
"sort"
"slices"
"strings"
"unicode"
"unicode/utf8"
@ -260,7 +261,7 @@ func modPathOK(r rune) bool {
// importPathOK reports whether r can appear in a package import path element.
//
// Import paths are intermediate between module paths and file paths: we allow
// Import paths are intermediate between module paths and file paths: we
// disallow characters that would be confusing or ambiguous as arguments to
// 'go get' (such as '@' and ' ' ), but allow certain characters that are
// otherwise-unambiguous on the command line and historically used for some
@ -657,17 +658,15 @@ func CanonicalVersion(v string) string {
// optionally followed by a tie-breaking suffix introduced by a slash character,
// like in "v0.0.1/go.mod".
func Sort(list []Version) {
sort.Slice(list, func(i, j int) bool {
mi := list[i]
mj := list[j]
if mi.Path != mj.Path {
return mi.Path < mj.Path
slices.SortFunc(list, func(i, j Version) int {
if i.Path != j.Path {
return strings.Compare(i.Path, j.Path)
}
// To help go.sum formatting, allow version/file.
// Compare semver prefix by semver rules,
// file by string order.
vi := mi.Version
vj := mj.Version
vi := i.Version
vj := j.Version
var fi, fj string
if k := strings.Index(vi, "/"); k >= 0 {
vi, fi = vi[:k], vi[k:]
@ -676,9 +675,9 @@ func Sort(list []Version) {
vj, fj = vj[:k], vj[k:]
}
if vi != vj {
return semver.Compare(vi, vj) < 0
return semver.Compare(vi, vj)
}
return fi < fj
return cmp.Compare(fi, fj)
})
}
@ -803,8 +802,8 @@ func MatchPrefixPatterns(globs, target string) bool {
for globs != "" {
// Extract next non-empty glob in comma-separated list.
var glob string
if i := strings.Index(globs, ","); i >= 0 {
glob, globs = globs[:i], globs[i+1:]
if before, after, ok := strings.Cut(globs, ","); ok {
glob, globs = before, after
} else {
glob, globs = globs, ""
}

View file

@ -22,7 +22,10 @@
// as shorthands for vMAJOR.0.0 and vMAJOR.MINOR.0.
package semver
import "sort"
import (
"slices"
"strings"
)
// parsed returns the parsed form of a semantic version string.
type parsed struct {
@ -42,8 +45,8 @@ func IsValid(v string) bool {
// Canonical returns the canonical formatting of the semantic version v.
// It fills in any missing .MINOR or .PATCH and discards build metadata.
// Two semantic versions compare equal only if their canonical formattings
// are identical strings.
// Two semantic versions compare equal only if their canonical formatting
// is an identical string.
// The canonical invalid semantic version is the empty string.
func Canonical(v string) string {
p, ok := parse(v)
@ -154,19 +157,22 @@ func Max(v, w string) string {
// ByVersion implements [sort.Interface] for sorting semantic version strings.
type ByVersion []string
func (vs ByVersion) Len() int { return len(vs) }
func (vs ByVersion) Swap(i, j int) { vs[i], vs[j] = vs[j], vs[i] }
func (vs ByVersion) Less(i, j int) bool {
cmp := Compare(vs[i], vs[j])
if cmp != 0 {
return cmp < 0
}
return vs[i] < vs[j]
func (vs ByVersion) Len() int { return len(vs) }
func (vs ByVersion) Swap(i, j int) { vs[i], vs[j] = vs[j], vs[i] }
func (vs ByVersion) Less(i, j int) bool { return compareVersion(vs[i], vs[j]) < 0 }
// Sort sorts a list of semantic version strings using [Compare] and falls back
// to use [strings.Compare] if both versions are considered equal.
func Sort(list []string) {
slices.SortFunc(list, compareVersion)
}
// Sort sorts a list of semantic version strings using [ByVersion].
func Sort(list []string) {
sort.Sort(ByVersion(list))
func compareVersion(a, b string) int {
cmp := Compare(a, b)
if cmp != 0 {
return cmp
}
return strings.Compare(a, b)
}
func parse(v string) (p parsed, ok bool) {

4
vendor/golang.org/x/net/LICENSE generated vendored
View file

@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
* Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

File diff suppressed because it is too large Load diff

View file

@ -78,16 +78,11 @@ example, to process each anchor node in depth-first order:
if err != nil {
// ...
}
var f func(*html.Node)
f = func(n *html.Node) {
for n := range doc.Descendants() {
if n.Type == html.ElementNode && n.Data == "a" {
// Do something with n...
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
f(c)
}
}
f(doc)
The relevant specifications include:
https://html.spec.whatwg.org/multipage/syntax.html and

View file

@ -87,7 +87,7 @@ func parseDoctype(s string) (n *Node, quirks bool) {
}
}
if lastAttr := n.Attr[len(n.Attr)-1]; lastAttr.Key == "system" &&
strings.ToLower(lastAttr.Val) == "http://www.ibm.com/data/dtd/v11/ibmxhtml1-transitional.dtd" {
strings.EqualFold(lastAttr.Val, "http://www.ibm.com/data/dtd/v11/ibmxhtml1-transitional.dtd") {
quirks = true
}
}

View file

@ -299,7 +299,7 @@ func escape(w writer, s string) error {
case '\r':
esc = "&#13;"
default:
panic("unrecognized escape character")
panic("html: unrecognized escape character")
}
s = s[i+1:]
if _, err := w.WriteString(esc); err != nil {

View file

@ -40,8 +40,7 @@ func htmlIntegrationPoint(n *Node) bool {
if n.Data == "annotation-xml" {
for _, a := range n.Attr {
if a.Key == "encoding" {
val := strings.ToLower(a.Val)
if val == "text/html" || val == "application/xhtml+xml" {
if strings.EqualFold(a.Val, "text/html") || strings.EqualFold(a.Val, "application/xhtml+xml") {
return true
}
}

56
vendor/golang.org/x/net/html/iter.go generated vendored Normal file
View file

@ -0,0 +1,56 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.23
package html
import "iter"
// Ancestors returns an iterator over the ancestors of n, starting with n.Parent.
//
// Mutating a Node or its parents while iterating may have unexpected results.
func (n *Node) Ancestors() iter.Seq[*Node] {
_ = n.Parent // eager nil check
return func(yield func(*Node) bool) {
for p := n.Parent; p != nil && yield(p); p = p.Parent {
}
}
}
// ChildNodes returns an iterator over the immediate children of n,
// starting with n.FirstChild.
//
// Mutating a Node or its children while iterating may have unexpected results.
func (n *Node) ChildNodes() iter.Seq[*Node] {
_ = n.FirstChild // eager nil check
return func(yield func(*Node) bool) {
for c := n.FirstChild; c != nil && yield(c); c = c.NextSibling {
}
}
}
// Descendants returns an iterator over all nodes recursively beneath
// n, excluding n itself. Nodes are visited in depth-first preorder.
//
// Mutating a Node or its descendants while iterating may have unexpected results.
func (n *Node) Descendants() iter.Seq[*Node] {
_ = n.FirstChild // eager nil check
return func(yield func(*Node) bool) {
n.descendants(yield)
}
}
func (n *Node) descendants(yield func(*Node) bool) bool {
for c := range n.ChildNodes() {
if !yield(c) || !c.descendants(yield) {
return false
}
}
return true
}

View file

@ -11,6 +11,7 @@ import (
// A NodeType is the type of a Node.
type NodeType uint32
//go:generate stringer -type NodeType
const (
ErrorNode NodeType = iota
TextNode
@ -38,6 +39,10 @@ var scopeMarker = Node{Type: scopeMarkerNode}
// that it looks like "a<b" rather than "a&lt;b". For element nodes, DataAtom
// is the atom for Data, or zero if Data is not a known tag name.
//
// Node trees may be navigated using the link fields (Parent,
// FirstChild, and so on) or a range loop over iterators such as
// [Node.Descendants].
//
// An empty Namespace implies a "http://www.w3.org/1999/xhtml" namespace.
// Similarly, "math" is short for "http://www.w3.org/1998/Math/MathML", and
// "svg" is short for "http://www.w3.org/2000/svg".

31
vendor/golang.org/x/net/html/nodetype_string.go generated vendored Normal file
View file

@ -0,0 +1,31 @@
// Code generated by "stringer -type NodeType"; DO NOT EDIT.
package html
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ErrorNode-0]
_ = x[TextNode-1]
_ = x[DocumentNode-2]
_ = x[ElementNode-3]
_ = x[CommentNode-4]
_ = x[DoctypeNode-5]
_ = x[RawNode-6]
_ = x[scopeMarkerNode-7]
}
const _NodeType_name = "ErrorNodeTextNodeDocumentNodeElementNodeCommentNodeDoctypeNodeRawNodescopeMarkerNode"
var _NodeType_index = [...]uint8{0, 9, 17, 29, 40, 51, 62, 69, 84}
func (i NodeType) String() string {
idx := int(i) - 0
if i < 0 || idx >= len(_NodeType_index)-1 {
return "NodeType(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _NodeType_name[_NodeType_index[idx]:_NodeType_index[idx+1]]
}

View file

@ -136,7 +136,7 @@ func (p *parser) indexOfElementInScope(s scope, matchTags ...a.Atom) int {
return -1
}
default:
panic("unreachable")
panic(fmt.Sprintf("html: internal error: indexOfElementInScope unknown scope: %d", s))
}
}
switch s {
@ -179,7 +179,7 @@ func (p *parser) clearStackToContext(s scope) {
return
}
default:
panic("unreachable")
panic(fmt.Sprintf("html: internal error: clearStackToContext unknown scope: %d", s))
}
}
}
@ -231,7 +231,14 @@ func (p *parser) addChild(n *Node) {
}
if n.Type == ElementNode {
p.oe = append(p.oe, n)
p.insertOpenElement(n)
}
}
func (p *parser) insertOpenElement(n *Node) {
p.oe = append(p.oe, n)
if len(p.oe) > 512 {
panic("html: open stack of elements exceeds 512 nodes")
}
}
@ -810,7 +817,7 @@ func afterHeadIM(p *parser) bool {
p.im = inFramesetIM
return true
case a.Base, a.Basefont, a.Bgsound, a.Link, a.Meta, a.Noframes, a.Script, a.Style, a.Template, a.Title:
p.oe = append(p.oe, p.head)
p.insertOpenElement(p.head)
defer p.oe.remove(p.head)
return inHeadIM(p)
case a.Head:
@ -840,6 +847,10 @@ func afterHeadIM(p *parser) bool {
p.parseImpliedToken(StartTagToken, a.Body, a.Body.String())
p.framesetOK = true
if p.tok.Type == ErrorToken {
// Stop parsing.
return true
}
return false
}
@ -920,7 +931,7 @@ func inBodyIM(p *parser) bool {
p.addElement()
p.im = inFramesetIM
return true
case a.Address, a.Article, a.Aside, a.Blockquote, a.Center, a.Details, a.Dialog, a.Dir, a.Div, a.Dl, a.Fieldset, a.Figcaption, a.Figure, a.Footer, a.Header, a.Hgroup, a.Main, a.Menu, a.Nav, a.Ol, a.P, a.Section, a.Summary, a.Ul:
case a.Address, a.Article, a.Aside, a.Blockquote, a.Center, a.Details, a.Dialog, a.Dir, a.Div, a.Dl, a.Fieldset, a.Figcaption, a.Figure, a.Footer, a.Header, a.Hgroup, a.Main, a.Menu, a.Nav, a.Ol, a.P, a.Search, a.Section, a.Summary, a.Ul:
p.popUntil(buttonScope, a.P)
p.addElement()
case a.H1, a.H2, a.H3, a.H4, a.H5, a.H6:
@ -1031,7 +1042,7 @@ func inBodyIM(p *parser) bool {
if p.tok.DataAtom == a.Input {
for _, t := range p.tok.Attr {
if t.Key == "type" {
if strings.ToLower(t.Val) == "hidden" {
if strings.EqualFold(t.Val, "hidden") {
// Skip setting framesetOK = false
return true
}
@ -1132,7 +1143,7 @@ func inBodyIM(p *parser) bool {
return false
}
return true
case a.Address, a.Article, a.Aside, a.Blockquote, a.Button, a.Center, a.Details, a.Dialog, a.Dir, a.Div, a.Dl, a.Fieldset, a.Figcaption, a.Figure, a.Footer, a.Header, a.Hgroup, a.Listing, a.Main, a.Menu, a.Nav, a.Ol, a.Pre, a.Section, a.Summary, a.Ul:
case a.Address, a.Article, a.Aside, a.Blockquote, a.Button, a.Center, a.Details, a.Dialog, a.Dir, a.Div, a.Dl, a.Fieldset, a.Figcaption, a.Figure, a.Footer, a.Header, a.Hgroup, a.Listing, a.Main, a.Menu, a.Nav, a.Ol, a.Pre, a.Search, a.Section, a.Summary, a.Ul:
p.popUntil(defaultScope, p.tok.DataAtom)
case a.Form:
if p.oe.contains(a.Template) {
@ -1459,7 +1470,7 @@ func inTableIM(p *parser) bool {
return inHeadIM(p)
case a.Input:
for _, t := range p.tok.Attr {
if t.Key == "type" && strings.ToLower(t.Val) == "hidden" {
if t.Key == "type" && strings.EqualFold(t.Val, "hidden") {
p.addElement()
p.oe.pop()
return true
@ -1674,7 +1685,7 @@ func inTableBodyIM(p *parser) bool {
return inTableIM(p)
}
// Section 12.2.6.4.14.
// Section 13.2.6.4.14.
func inRowIM(p *parser) bool {
switch p.tok.Type {
case StartTagToken:
@ -1686,7 +1697,9 @@ func inRowIM(p *parser) bool {
p.im = inCellIM
return true
case a.Caption, a.Col, a.Colgroup, a.Tbody, a.Tfoot, a.Thead, a.Tr:
if p.popUntil(tableScope, a.Tr) {
if p.elementInScope(tableScope, a.Tr) {
p.clearStackToContext(tableRowScope)
p.oe.pop()
p.im = inTableBodyIM
return false
}
@ -1696,22 +1709,28 @@ func inRowIM(p *parser) bool {
case EndTagToken:
switch p.tok.DataAtom {
case a.Tr:
if p.popUntil(tableScope, a.Tr) {
if p.elementInScope(tableScope, a.Tr) {
p.clearStackToContext(tableRowScope)
p.oe.pop()
p.im = inTableBodyIM
return true
}
// Ignore the token.
return true
case a.Table:
if p.popUntil(tableScope, a.Tr) {
if p.elementInScope(tableScope, a.Tr) {
p.clearStackToContext(tableRowScope)
p.oe.pop()
p.im = inTableBodyIM
return false
}
// Ignore the token.
return true
case a.Tbody, a.Tfoot, a.Thead:
if p.elementInScope(tableScope, p.tok.DataAtom) {
p.parseImpliedToken(EndTagToken, a.Tr, a.Tr.String())
if p.elementInScope(tableScope, p.tok.DataAtom) && p.elementInScope(tableScope, a.Tr) {
p.clearStackToContext(tableRowScope)
p.oe.pop()
p.im = inTableBodyIM
return false
}
// Ignore the token.
@ -2218,16 +2237,20 @@ func parseForeignContent(p *parser) bool {
p.acknowledgeSelfClosingTag()
}
case EndTagToken:
if strings.EqualFold(p.oe[len(p.oe)-1].Data, p.tok.Data) {
p.oe = p.oe[:len(p.oe)-1]
return true
}
for i := len(p.oe) - 1; i >= 0; i-- {
if p.oe[i].Namespace == "" {
return p.im(p)
}
if strings.EqualFold(p.oe[i].Data, p.tok.Data) {
p.oe = p.oe[:i]
return true
}
if i > 0 && p.oe[i-1].Namespace == "" {
break
}
}
return true
return p.im(p)
default:
// Ignore the token.
}
@ -2308,9 +2331,13 @@ func (p *parser) parseCurrentToken() {
}
}
func (p *parser) parse() error {
func (p *parser) parse() (err error) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = fmt.Errorf("%s", panicErr)
}
}()
// Iterate until EOF. Any other error will cause an early return.
var err error
for err != io.EOF {
// CDATA sections are allowed only in foreign content.
n := p.oe.top()
@ -2339,6 +2366,8 @@ func (p *parser) parse() error {
// <tag>s. Conversely, explicit <tag>s in r's data can be silently dropped,
// with no corresponding node in the resulting tree.
//
// Parse will reject HTML that is nested deeper than 512 elements.
//
// The input is assumed to be UTF-8 encoded.
func Parse(r io.Reader) (*Node, error) {
return ParseWithOptions(r)

View file

@ -184,7 +184,7 @@ func render1(w writer, n *Node) error {
return err
}
// Add initial newline where there is danger of a newline beging ignored.
// Add initial newline where there is danger of a newline being ignored.
if c := n.FirstChild; c != nil && c.Type == TextNode && strings.HasPrefix(c.Data, "\n") {
switch n.Data {
case "pre", "listing", "textarea":

View file

@ -839,8 +839,22 @@ func (z *Tokenizer) readStartTag() TokenType {
if raw {
z.rawTag = strings.ToLower(string(z.buf[z.data.start:z.data.end]))
}
// Look for a self-closing token like "<br/>".
if z.err == nil && z.buf[z.raw.end-2] == '/' {
// Look for a self-closing token (e.g. <br/>).
//
// Originally, we did this by just checking that the last character of the
// tag (ignoring the closing bracket) was a solidus (/) character, but this
// is not always accurate.
//
// We need to be careful that we don't misinterpret a non-self-closing tag
// as self-closing, as can happen if the tag contains unquoted attribute
// values (i.e. <p a=/>).
//
// To avoid this, we check that the last non-bracket character of the tag
// (z.raw.end-2) isn't the same character as the last non-quote character of
// the last attribute of the tag (z.pendingAttr[1].end-1), if the tag has
// attributes.
nAttrs := len(z.attr)
if z.err == nil && z.buf[z.raw.end-2] == '/' && (nAttrs == 0 || z.raw.end-2 != z.attr[nAttrs-1][1].end-1) {
return SelfClosingTagToken
}
return StartTagToken

View file

@ -8,8 +8,8 @@ package http2
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"sync"
)
@ -158,7 +158,7 @@ func (c *dialCall) dial(ctx context.Context, addr string) {
// This code decides which ones live or die.
// The return value used is whether c was used.
// c is never closed.
func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c *tls.Conn) (used bool, err error) {
func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) {
p.mu.Lock()
for _, cc := range p.conns[key] {
if cc.CanTakeNewRequest() {
@ -194,8 +194,8 @@ type addConnCall struct {
err error
}
func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) {
cc, err := t.NewClientConn(tc)
func (c *addConnCall) run(t *Transport, key string, nc net.Conn) {
cc, err := t.NewClientConn(nc)
p := c.p
p.mu.Lock()

20
vendor/golang.org/x/net/http2/client_priority_go126.go generated vendored Normal file
View file

@ -0,0 +1,20 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.27
package http2
import "net/http"
// Support for go.dev/issue/75500 is added in Go 1.27. In case anyone uses
// x/net with versions before Go 1.27, we return true here so that their write
// scheduler will still be the round-robin write scheduler rather than the RFC
// 9218 write scheduler. That way, older users of Go will not see a sudden
// change of behavior just from importing x/net.
//
// TODO(nsh): remove this file after x/net go.mod is at Go 1.27.
func clientPriorityDisabled(_ *http.Server) bool {
return true
}

13
vendor/golang.org/x/net/http2/client_priority_go127.go generated vendored Normal file
View file

@ -0,0 +1,13 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.27
package http2
import "net/http"
func clientPriorityDisabled(s *http.Server) bool {
return s.DisableClientPriority
}

169
vendor/golang.org/x/net/http2/config.go generated vendored Normal file
View file

@ -0,0 +1,169 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"math"
"net/http"
"time"
)
// http2Config is a package-internal version of net/http.HTTP2Config.
//
// http.HTTP2Config was added in Go 1.24.
// When running with a version of net/http that includes HTTP2Config,
// we merge the configuration with the fields in Transport or Server
// to produce an http2Config.
//
// Zero valued fields in http2Config are interpreted as in the
// net/http.HTTPConfig documentation.
//
// Precedence order for reconciling configurations is:
//
// - Use the net/http.{Server,Transport}.HTTP2Config value, when non-zero.
// - Otherwise use the http2.{Server.Transport} value.
// - If the resulting value is zero or out of range, use a default.
type http2Config struct {
MaxConcurrentStreams uint32
StrictMaxConcurrentRequests bool
MaxDecoderHeaderTableSize uint32
MaxEncoderHeaderTableSize uint32
MaxReadFrameSize uint32
MaxUploadBufferPerConnection int32
MaxUploadBufferPerStream int32
SendPingTimeout time.Duration
PingTimeout time.Duration
WriteByteTimeout time.Duration
PermitProhibitedCipherSuites bool
CountError func(errType string)
}
// configFromServer merges configuration settings from
// net/http.Server.HTTP2Config and http2.Server.
func configFromServer(h1 *http.Server, h2 *Server) http2Config {
conf := http2Config{
MaxConcurrentStreams: h2.MaxConcurrentStreams,
MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
MaxReadFrameSize: h2.MaxReadFrameSize,
MaxUploadBufferPerConnection: h2.MaxUploadBufferPerConnection,
MaxUploadBufferPerStream: h2.MaxUploadBufferPerStream,
SendPingTimeout: h2.ReadIdleTimeout,
PingTimeout: h2.PingTimeout,
WriteByteTimeout: h2.WriteByteTimeout,
PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites,
CountError: h2.CountError,
}
fillNetHTTPConfig(&conf, h1.HTTP2)
setConfigDefaults(&conf, true)
return conf
}
// configFromTransport merges configuration settings from h2 and h2.t1.HTTP2
// (the net/http Transport).
func configFromTransport(h2 *Transport) http2Config {
conf := http2Config{
StrictMaxConcurrentRequests: h2.StrictMaxConcurrentStreams,
MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
MaxReadFrameSize: h2.MaxReadFrameSize,
SendPingTimeout: h2.ReadIdleTimeout,
PingTimeout: h2.PingTimeout,
WriteByteTimeout: h2.WriteByteTimeout,
}
// Unlike most config fields, where out-of-range values revert to the default,
// Transport.MaxReadFrameSize clips.
if conf.MaxReadFrameSize < minMaxFrameSize {
conf.MaxReadFrameSize = minMaxFrameSize
} else if conf.MaxReadFrameSize > maxFrameSize {
conf.MaxReadFrameSize = maxFrameSize
}
if h2.t1 != nil {
fillNetHTTPConfig(&conf, h2.t1.HTTP2)
}
setConfigDefaults(&conf, false)
return conf
}
func setDefault[T ~int | ~int32 | ~uint32 | ~int64](v *T, minval, maxval, defval T) {
if *v < minval || *v > maxval {
*v = defval
}
}
func setConfigDefaults(conf *http2Config, server bool) {
setDefault(&conf.MaxConcurrentStreams, 1, math.MaxUint32, defaultMaxStreams)
setDefault(&conf.MaxEncoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
setDefault(&conf.MaxDecoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
if server {
setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, 1<<20)
} else {
setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, transportDefaultConnFlow)
}
if server {
setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, 1<<20)
} else {
setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, transportDefaultStreamFlow)
}
setDefault(&conf.MaxReadFrameSize, minMaxFrameSize, maxFrameSize, defaultMaxReadFrameSize)
setDefault(&conf.PingTimeout, 1, math.MaxInt64, 15*time.Second)
}
// adjustHTTP1MaxHeaderSize converts a limit in bytes on the size of an HTTP/1 header
// to an HTTP/2 MAX_HEADER_LIST_SIZE value.
func adjustHTTP1MaxHeaderSize(n int64) int64 {
// http2's count is in a slightly different unit and includes 32 bytes per pair.
// So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
const perFieldOverhead = 32 // per http2 spec
const typicalHeaders = 10 // conservative
return n + typicalHeaders*perFieldOverhead
}
func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) {
if h2 == nil {
return
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if http2ConfigStrictMaxConcurrentRequests(h2) {
conf.StrictMaxConcurrentRequests = true
}
if h2.MaxEncoderHeaderTableSize != 0 {
conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize)
}
if h2.MaxDecoderHeaderTableSize != 0 {
conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize)
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if h2.MaxReadFrameSize != 0 {
conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize)
}
if h2.MaxReceiveBufferPerConnection != 0 {
conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection)
}
if h2.MaxReceiveBufferPerStream != 0 {
conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream)
}
if h2.SendPingTimeout != 0 {
conf.SendPingTimeout = h2.SendPingTimeout
}
if h2.PingTimeout != 0 {
conf.PingTimeout = h2.PingTimeout
}
if h2.WriteByteTimeout != 0 {
conf.WriteByteTimeout = h2.WriteByteTimeout
}
if h2.PermitProhibitedCipherSuites {
conf.PermitProhibitedCipherSuites = true
}
if h2.CountError != nil {
conf.CountError = h2.CountError
}
}

15
vendor/golang.org/x/net/http2/config_go125.go generated vendored Normal file
View file

@ -0,0 +1,15 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.26
package http2
import (
"net/http"
)
func http2ConfigStrictMaxConcurrentRequests(h2 *http.HTTP2Config) bool {
return false
}

15
vendor/golang.org/x/net/http2/config_go126.go generated vendored Normal file
View file

@ -0,0 +1,15 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.26
package http2
import (
"net/http"
)
func http2ConfigStrictMaxConcurrentRequests(h2 *http.HTTP2Config) bool {
return h2.StrictMaxConcurrentRequests
}

View file

@ -11,11 +11,13 @@ import (
"fmt"
"io"
"log"
"slices"
"strings"
"sync"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/internal/httpsfv"
)
const frameHeaderLen = 9
@ -23,40 +25,43 @@ const frameHeaderLen = 9
var padZeros = make([]byte, 255) // zeros for padding
// A FrameType is a registered frame type as defined in
// https://httpwg.org/specs/rfc7540.html#rfc.section.11.2
// https://httpwg.org/specs/rfc7540.html#rfc.section.11.2 and other future
// RFCs.
type FrameType uint8
const (
FrameData FrameType = 0x0
FrameHeaders FrameType = 0x1
FramePriority FrameType = 0x2
FrameRSTStream FrameType = 0x3
FrameSettings FrameType = 0x4
FramePushPromise FrameType = 0x5
FramePing FrameType = 0x6
FrameGoAway FrameType = 0x7
FrameWindowUpdate FrameType = 0x8
FrameContinuation FrameType = 0x9
FrameData FrameType = 0x0
FrameHeaders FrameType = 0x1
FramePriority FrameType = 0x2
FrameRSTStream FrameType = 0x3
FrameSettings FrameType = 0x4
FramePushPromise FrameType = 0x5
FramePing FrameType = 0x6
FrameGoAway FrameType = 0x7
FrameWindowUpdate FrameType = 0x8
FrameContinuation FrameType = 0x9
FramePriorityUpdate FrameType = 0x10
)
var frameName = map[FrameType]string{
FrameData: "DATA",
FrameHeaders: "HEADERS",
FramePriority: "PRIORITY",
FrameRSTStream: "RST_STREAM",
FrameSettings: "SETTINGS",
FramePushPromise: "PUSH_PROMISE",
FramePing: "PING",
FrameGoAway: "GOAWAY",
FrameWindowUpdate: "WINDOW_UPDATE",
FrameContinuation: "CONTINUATION",
var frameNames = [...]string{
FrameData: "DATA",
FrameHeaders: "HEADERS",
FramePriority: "PRIORITY",
FrameRSTStream: "RST_STREAM",
FrameSettings: "SETTINGS",
FramePushPromise: "PUSH_PROMISE",
FramePing: "PING",
FrameGoAway: "GOAWAY",
FrameWindowUpdate: "WINDOW_UPDATE",
FrameContinuation: "CONTINUATION",
FramePriorityUpdate: "PRIORITY_UPDATE",
}
func (t FrameType) String() string {
if s, ok := frameName[t]; ok {
return s
if int(t) < len(frameNames) {
return frameNames[t]
}
return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t))
return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", t)
}
// Flags is a bitmask of HTTP/2 flags.
@ -124,22 +129,23 @@ var flagName = map[FrameType]map[Flags]string{
// might be 0).
type frameParser func(fc *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error)
var frameParsers = map[FrameType]frameParser{
FrameData: parseDataFrame,
FrameHeaders: parseHeadersFrame,
FramePriority: parsePriorityFrame,
FrameRSTStream: parseRSTStreamFrame,
FrameSettings: parseSettingsFrame,
FramePushPromise: parsePushPromise,
FramePing: parsePingFrame,
FrameGoAway: parseGoAwayFrame,
FrameWindowUpdate: parseWindowUpdateFrame,
FrameContinuation: parseContinuationFrame,
var frameParsers = [...]frameParser{
FrameData: parseDataFrame,
FrameHeaders: parseHeadersFrame,
FramePriority: parsePriorityFrame,
FrameRSTStream: parseRSTStreamFrame,
FrameSettings: parseSettingsFrame,
FramePushPromise: parsePushPromise,
FramePing: parsePingFrame,
FrameGoAway: parseGoAwayFrame,
FrameWindowUpdate: parseWindowUpdateFrame,
FrameContinuation: parseContinuationFrame,
FramePriorityUpdate: parsePriorityUpdateFrame,
}
func typeFrameParser(t FrameType) frameParser {
if f := frameParsers[t]; f != nil {
return f
if int(t) < len(frameParsers) {
return frameParsers[t]
}
return parseUnknownFrame
}
@ -225,6 +231,11 @@ var fhBytes = sync.Pool{
},
}
func invalidHTTP1LookingFrameHeader() FrameHeader {
fh, _ := readFrameHeader(make([]byte, frameHeaderLen), strings.NewReader("HTTP/1.1 "))
return fh
}
// ReadFrameHeader reads 9 bytes from r and returns a FrameHeader.
// Most users should use Framer.ReadFrame instead.
func ReadFrameHeader(r io.Reader) (FrameHeader, error) {
@ -275,6 +286,8 @@ type Framer struct {
// lastHeaderStream is non-zero if the last frame was an
// unfinished HEADERS/CONTINUATION.
lastHeaderStream uint32
// lastFrameType holds the type of the last frame for verifying frame order.
lastFrameType FrameType
maxReadSize uint32
headerBuf [frameHeaderLen]byte
@ -342,7 +355,7 @@ func (fr *Framer) maxHeaderListSize() uint32 {
func (f *Framer) startWrite(ftype FrameType, flags Flags, streamID uint32) {
// Write the FrameHeader.
f.wbuf = append(f.wbuf[:0],
0, // 3 bytes of length, filled in in endWrite
0, // 3 bytes of length, filled in endWrite
0,
0,
byte(ftype),
@ -483,30 +496,47 @@ func terminalReadFrameError(err error) bool {
return err != nil
}
// ReadFrame reads a single frame. The returned Frame is only valid
// until the next call to ReadFrame.
// ReadFrameHeader reads the header of the next frame.
// It reads the 9-byte fixed frame header, and does not read any portion of the
// frame payload. The caller is responsible for consuming the payload, either
// with ReadFrameForHeader or directly from the Framer's io.Reader.
//
// If the frame is larger than previously set with SetMaxReadFrameSize, the
// returned error is ErrFrameTooLarge. Other errors may be of type
// ConnectionError, StreamError, or anything else from the underlying
// reader.
// If the frame is larger than previously set with SetMaxReadFrameSize, it
// returns the frame header and ErrFrameTooLarge.
//
// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID
// indicates the stream responsible for the error.
func (fr *Framer) ReadFrame() (Frame, error) {
// If the returned FrameHeader.StreamID is non-zero, it indicates the stream
// responsible for the error.
func (fr *Framer) ReadFrameHeader() (FrameHeader, error) {
fr.errDetail = nil
fh, err := readFrameHeader(fr.headerBuf[:], fr.r)
if err != nil {
return fh, err
}
if fh.Length > fr.maxReadSize {
if fh == invalidHTTP1LookingFrameHeader() {
return fh, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", ErrFrameTooLarge)
}
return fh, ErrFrameTooLarge
}
if err := fr.checkFrameOrder(fh); err != nil {
return fh, err
}
return fh, nil
}
// ReadFrameForHeader reads the payload for the frame with the given FrameHeader.
//
// It behaves identically to ReadFrame, other than not checking the maximum
// frame size.
func (fr *Framer) ReadFrameForHeader(fh FrameHeader) (Frame, error) {
if fr.lastFrame != nil {
fr.lastFrame.invalidate()
}
fh, err := readFrameHeader(fr.headerBuf[:], fr.r)
if err != nil {
return nil, err
}
if fh.Length > fr.maxReadSize {
return nil, ErrFrameTooLarge
}
payload := fr.getReadBuf(fh.Length)
if _, err := io.ReadFull(fr.r, payload); err != nil {
if fh == invalidHTTP1LookingFrameHeader() {
return nil, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", err)
}
return nil, err
}
f, err := typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload)
@ -516,9 +546,7 @@ func (fr *Framer) ReadFrame() (Frame, error) {
}
return nil, err
}
if err := fr.checkFrameOrder(f); err != nil {
return nil, err
}
fr.lastFrame = f
if fr.logReads {
fr.debugReadLoggerf("http2: Framer %p: read %v", fr, summarizeFrame(f))
}
@ -528,6 +556,24 @@ func (fr *Framer) ReadFrame() (Frame, error) {
return f, nil
}
// ReadFrame reads a single frame. The returned Frame is only valid
// until the next call to ReadFrame or ReadFrameBodyForHeader.
//
// If the frame is larger than previously set with SetMaxReadFrameSize, the
// returned error is ErrFrameTooLarge. Other errors may be of type
// ConnectionError, StreamError, or anything else from the underlying
// reader.
//
// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID
// indicates the stream responsible for the error.
func (fr *Framer) ReadFrame() (Frame, error) {
fh, err := fr.ReadFrameHeader()
if err != nil {
return nil, err
}
return fr.ReadFrameForHeader(fh)
}
// connError returns ConnectionError(code) but first
// stashes away a public reason to the caller can optionally relay it
// to the peer before hanging up on them. This might help others debug
@ -540,20 +586,19 @@ func (fr *Framer) connError(code ErrCode, reason string) error {
// checkFrameOrder reports an error if f is an invalid frame to return
// next from ReadFrame. Mostly it checks whether HEADERS and
// CONTINUATION frames are contiguous.
func (fr *Framer) checkFrameOrder(f Frame) error {
last := fr.lastFrame
fr.lastFrame = f
func (fr *Framer) checkFrameOrder(fh FrameHeader) error {
lastType := fr.lastFrameType
fr.lastFrameType = fh.Type
if fr.AllowIllegalReads {
return nil
}
fh := f.Header()
if fr.lastHeaderStream != 0 {
if fh.Type != FrameContinuation {
return fr.connError(ErrCodeProtocol,
fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d",
fh.Type, fh.StreamID,
last.Header().Type, fr.lastHeaderStream))
lastType, fr.lastHeaderStream))
}
if fh.StreamID != fr.lastHeaderStream {
return fr.connError(ErrCodeProtocol,
@ -1141,7 +1186,41 @@ type PriorityFrame struct {
PriorityParam
}
// PriorityParam are the stream prioritzation parameters.
// defaultRFC9218Priority determines what priority we should use as the default
// value.
//
// According to RFC 9218, by default, streams should be given an urgency of 3
// and should be non-incremental. However, making streams non-incremental by
// default would be a huge change to our historical behavior where we would
// round-robin writes across streams. When streams are non-incremental, we
// would process streams of the same urgency one-by-one to completion instead.
//
// To avoid such a sudden change which might break some HTTP/2 users, this
// function allows the caller to specify whether they can actually use the
// default value as specified in RFC 9218. If not, this function will return a
// priority value where streams are incremental by default instead: effectively
// a round-robin between stream of the same urgency.
//
// As an example, a server might not be able to use the RFC 9218 default value
// when it's not sure that the client it is serving is aware of RFC 9218.
func defaultRFC9218Priority(canUseDefault bool) PriorityParam {
if canUseDefault {
return PriorityParam{
urgency: 3,
incremental: 0,
}
}
return PriorityParam{
urgency: 3,
incremental: 1,
}
}
// Note that HTTP/2 has had two different prioritization schemes, and
// PriorityParam struct below is a superset of both schemes. The exported
// symbols are from RFC 7540 and the non-exported ones are from RFC 9218.
// PriorityParam are the stream prioritization parameters.
type PriorityParam struct {
// StreamDep is a 31-bit stream identifier for the
// stream that this stream depends on. Zero means no
@ -1156,6 +1235,20 @@ type PriorityParam struct {
// the spec, "Add one to the value to obtain a weight between
// 1 and 256."
Weight uint8
// "The urgency (u) parameter value is Integer (see Section 3.3.1 of
// [STRUCTURED-FIELDS]), between 0 and 7 inclusive, in descending order of
// priority. The default is 3."
urgency uint8
// "The incremental (i) parameter value is Boolean (see Section 3.3.6 of
// [STRUCTURED-FIELDS]). It indicates if an HTTP response can be processed
// incrementally, i.e., provide some meaningful output as chunks of the
// response arrive."
//
// We use uint8 (i.e. 0 is false, 1 is true) instead of bool so we can
// avoid unnecessary type conversions and because either type takes 1 byte.
incremental uint8
}
func (p PriorityParam) IsZero() bool {
@ -1204,6 +1297,74 @@ func (f *Framer) WritePriority(streamID uint32, p PriorityParam) error {
return f.endWrite()
}
// PriorityUpdateFrame is a PRIORITY_UPDATE frame as described in
// https://www.rfc-editor.org/rfc/rfc9218.html#name-the-priority_update-frame.
type PriorityUpdateFrame struct {
FrameHeader
Priority string
PrioritizedStreamID uint32
}
func parseRFC9218Priority(s string, canUseDefault bool) (p PriorityParam, ok bool) {
p = defaultRFC9218Priority(canUseDefault)
ok = httpsfv.ParseDictionary(s, func(key, val, _ string) {
switch key {
case "u":
if u, ok := httpsfv.ParseInteger(val); ok && u >= 0 && u <= 7 {
p.urgency = uint8(u)
}
case "i":
if i, ok := httpsfv.ParseBoolean(val); ok {
if i {
p.incremental = 1
} else {
p.incremental = 0
}
}
}
})
if !ok {
return defaultRFC9218Priority(canUseDefault), ok
}
return p, true
}
func parsePriorityUpdateFrame(_ *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) {
if fh.StreamID != 0 {
countError("frame_priority_update_non_zero_stream")
return nil, connError{ErrCodeProtocol, "PRIORITY_UPDATE frame with non-zero stream ID"}
}
if len(payload) < 4 {
countError("frame_priority_update_bad_length")
return nil, connError{ErrCodeFrameSize, fmt.Sprintf("PRIORITY_UPDATE frame payload size was %d; want at least 4", len(payload))}
}
v := binary.BigEndian.Uint32(payload[:4])
streamID := v & 0x7fffffff // mask off high bit
if streamID == 0 {
countError("frame_priority_update_prioritizing_zero_stream")
return nil, connError{ErrCodeProtocol, "PRIORITY_UPDATE frame with prioritized stream ID of zero"}
}
return &PriorityUpdateFrame{
FrameHeader: fh,
PrioritizedStreamID: streamID,
Priority: string(payload[4:]),
}, nil
}
// WritePriorityUpdate writes a PRIORITY_UPDATE frame.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
func (f *Framer) WritePriorityUpdate(streamID uint32, priority string) error {
if !validStreamID(streamID) && !f.AllowIllegalWrites {
return errStreamID
}
f.startWrite(FramePriorityUpdate, 0, 0)
f.writeUint32(streamID)
f.writeBytes([]byte(priority))
return f.endWrite()
}
// A RSTStreamFrame allows for abnormal termination of a stream.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.4
type RSTStreamFrame struct {
@ -1485,12 +1646,29 @@ func (mh *MetaHeadersFrame) PseudoFields() []hpack.HeaderField {
return mh.Fields
}
func (mh *MetaHeadersFrame) rfc9218Priority(priorityAware bool) (p PriorityParam, priorityAwareAfter, hasIntermediary bool) {
var s string
for _, field := range mh.Fields {
if field.Name == "priority" {
s = field.Value
priorityAware = true
}
if slices.Contains([]string{"via", "forwarded", "x-forwarded-for"}, field.Name) {
hasIntermediary = true
}
}
// No need to check for ok. parseRFC9218Priority will return a default
// value if there is no priority field or if the field cannot be parsed.
p, _ = parseRFC9218Priority(s, priorityAware && !hasIntermediary)
return p, priorityAware, hasIntermediary
}
func (mh *MetaHeadersFrame) checkPseudos() error {
var isRequest, isResponse bool
pf := mh.PseudoFields()
for i, hf := range pf {
switch hf.Name {
case ":method", ":path", ":scheme", ":authority":
case ":method", ":path", ":scheme", ":authority", ":protocol":
isRequest = true
case ":status":
isResponse = true
@ -1498,7 +1676,7 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
return pseudoHeaderError(hf.Name)
}
// Check for duplicates.
// This would be a bad algorithm, but N is 4.
// This would be a bad algorithm, but N is 5.
// And this doesn't allocate.
for _, hf2 := range pf[:i] {
if hf.Name == hf2.Name {

View file

@ -15,21 +15,32 @@ import (
"runtime"
"strconv"
"sync"
"sync/atomic"
)
var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
// Setting DebugGoroutines to false during a test to disable goroutine debugging
// results in race detector complaints when a test leaves goroutines running before
// returning. Tests shouldn't do this, of course, but when they do it generally shows
// up as infrequent, hard-to-debug flakes. (See #66519.)
//
// Disable goroutine debugging during individual tests with an atomic bool.
// (Note that it's safe to enable/disable debugging mid-test, so the actual race condition
// here is harmless.)
var disableDebugGoroutines atomic.Bool
type goroutineLock uint64
func newGoroutineLock() goroutineLock {
if !DebugGoroutines {
if !DebugGoroutines || disableDebugGoroutines.Load() {
return 0
}
return goroutineLock(curGoroutineID())
}
func (g goroutineLock) check() {
if !DebugGoroutines {
if !DebugGoroutines || disableDebugGoroutines.Load() {
return
}
if curGoroutineID() != uint64(g) {
@ -38,7 +49,7 @@ func (g goroutineLock) check() {
}
func (g goroutineLock) checkNotOn() {
if !DebugGoroutines {
if !DebugGoroutines || disableDebugGoroutines.Load() {
return
}
if curGoroutineID() == uint64(g) {

View file

@ -132,11 +132,8 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// of the body, and reforward the client preface on the net.Conn this function
// creates.
func initH2CWithPriorKnowledge(w http.ResponseWriter) (net.Conn, error) {
hijacker, ok := w.(http.Hijacker)
if !ok {
return nil, errors.New("h2c: connection does not support Hijack")
}
conn, rw, err := hijacker.Hijack()
rc := http.NewResponseController(w)
conn, rw, err := rc.Hijack()
if err != nil {
return nil, err
}
@ -163,10 +160,6 @@ func h2cUpgrade(w http.ResponseWriter, r *http.Request) (_ net.Conn, settings []
if err != nil {
return nil, nil, err
}
hijacker, ok := w.(http.Hijacker)
if !ok {
return nil, nil, errors.New("h2c: connection does not support Hijack")
}
body, err := io.ReadAll(r.Body)
if err != nil {
@ -174,7 +167,8 @@ func h2cUpgrade(w http.ResponseWriter, r *http.Request) (_ net.Conn, settings []
}
r.Body = io.NopCloser(bytes.NewBuffer(body))
conn, rw, err := hijacker.Hijack()
rc := http.NewResponseController(w)
conn, rw, err := rc.Hijack()
if err != nil {
return nil, nil, err
}

View file

@ -11,21 +11,21 @@
// requires Go 1.6 or later)
//
// See https://http2.github.io/ for more information on HTTP/2.
//
// See https://http2.golang.org/ for a test server running this code.
package http2 // import "golang.org/x/net/http2"
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"sort"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/net/http/httpguts"
)
@ -34,7 +34,15 @@ var (
VerboseLogs bool
logFrameWrites bool
logFrameReads bool
inTests bool
// Enabling extended CONNECT by causes browsers to attempt to use
// WebSockets-over-HTTP/2. This results in problems when the server's websocket
// package doesn't support extended CONNECT.
//
// Disable extended CONNECT by default for now.
//
// Issue #71128.
disableExtendedConnectProtocol = true
)
func init() {
@ -47,6 +55,9 @@ func init() {
logFrameWrites = true
logFrameReads = true
}
if strings.Contains(e, "http2xconnect=1") {
disableExtendedConnectProtocol = false
}
}
const (
@ -138,6 +149,10 @@ func (s Setting) Valid() error {
if s.Val < 16384 || s.Val > 1<<24-1 {
return ConnectionError(ErrCodeProtocol)
}
case SettingEnableConnectProtocol:
if s.Val != 1 && s.Val != 0 {
return ConnectionError(ErrCodeProtocol)
}
}
return nil
}
@ -147,21 +162,25 @@ func (s Setting) Valid() error {
type SettingID uint16
const (
SettingHeaderTableSize SettingID = 0x1
SettingEnablePush SettingID = 0x2
SettingMaxConcurrentStreams SettingID = 0x3
SettingInitialWindowSize SettingID = 0x4
SettingMaxFrameSize SettingID = 0x5
SettingMaxHeaderListSize SettingID = 0x6
SettingHeaderTableSize SettingID = 0x1
SettingEnablePush SettingID = 0x2
SettingMaxConcurrentStreams SettingID = 0x3
SettingInitialWindowSize SettingID = 0x4
SettingMaxFrameSize SettingID = 0x5
SettingMaxHeaderListSize SettingID = 0x6
SettingEnableConnectProtocol SettingID = 0x8
SettingNoRFC7540Priorities SettingID = 0x9
)
var settingName = map[SettingID]string{
SettingHeaderTableSize: "HEADER_TABLE_SIZE",
SettingEnablePush: "ENABLE_PUSH",
SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
SettingMaxFrameSize: "MAX_FRAME_SIZE",
SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
SettingHeaderTableSize: "HEADER_TABLE_SIZE",
SettingEnablePush: "ENABLE_PUSH",
SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
SettingMaxFrameSize: "MAX_FRAME_SIZE",
SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
SettingEnableConnectProtocol: "ENABLE_CONNECT_PROTOCOL",
SettingNoRFC7540Priorities: "NO_RFC7540_PRIORITIES",
}
func (s SettingID) String() string {
@ -210,12 +229,6 @@ type stringWriter interface {
WriteString(s string) (n int, err error)
}
// A gate lets two goroutines coordinate their activities.
type gate chan struct{}
func (g gate) Done() { g <- struct{}{} }
func (g gate) Wait() { <-g }
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
type closeWaiter chan struct{}
@ -241,13 +254,17 @@ func (cw closeWaiter) Wait() {
// Its buffered writer is lazily allocated as needed, to minimize
// idle memory usage with many connections.
type bufferedWriter struct {
_ incomparable
w io.Writer // immutable
bw *bufio.Writer // non-nil when data is buffered
_ incomparable
conn net.Conn // immutable
bw *bufio.Writer // non-nil when data is buffered
byteTimeout time.Duration // immutable, WriteByteTimeout
}
func newBufferedWriter(w io.Writer) *bufferedWriter {
return &bufferedWriter{w: w}
func newBufferedWriter(conn net.Conn, timeout time.Duration) *bufferedWriter {
return &bufferedWriter{
conn: conn,
byteTimeout: timeout,
}
}
// bufWriterPoolBufferSize is the size of bufio.Writer's
@ -274,7 +291,7 @@ func (w *bufferedWriter) Available() int {
func (w *bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil {
bw := bufWriterPool.Get().(*bufio.Writer)
bw.Reset(w.w)
bw.Reset((*bufferedWriterTimeoutWriter)(w))
w.bw = bw
}
return w.bw.Write(p)
@ -292,6 +309,32 @@ func (w *bufferedWriter) Flush() error {
return err
}
type bufferedWriterTimeoutWriter bufferedWriter
func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
return writeWithByteTimeout(w.conn, w.byteTimeout, p)
}
// writeWithByteTimeout writes to conn.
// If more than timeout passes without any bytes being written to the connection,
// the write fails.
func writeWithByteTimeout(conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
if timeout <= 0 {
return conn.Write(p)
}
for {
conn.SetWriteDeadline(time.Now().Add(timeout))
nn, err := conn.Write(p[n:])
n += nn
if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
// Either we finished the write, made no progress, or hit the deadline.
// Whichever it is, we're done now.
conn.SetWriteDeadline(time.Time{})
return n, err
}
}
}
func mustUint31(v int32) uint32 {
if v < 0 || v > 2147483647 {
panic("out of range")
@ -362,23 +405,6 @@ func (s *sorter) SortStrings(ss []string) {
s.v = save
}
// validPseudoPath reports whether v is a valid :path pseudo-header
// value. It must be either:
//
// - a non-empty string starting with '/'
// - the string '*', for OPTIONS requests.
//
// For now this is only used a quick check for deciding when to clean
// up Opaque URLs before sending requests from the Transport.
// See golang.org/issue/16847
//
// We used to enforce that the path also didn't start with "//", but
// Google's GFE accepts such paths and Chrome sends them, so ignore
// that part of the spec. See golang.org/issue/19103.
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/') || v == "*"
}
// incomparable is a zero-width, non-comparable type. Adding it to a struct
// makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first).

View file

@ -29,6 +29,7 @@ import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
@ -49,13 +50,18 @@ import (
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/internal/httpcommon"
)
const (
prefaceTimeout = 10 * time.Second
firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
handlerChunkWriteSize = 4 << 10
defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
prefaceTimeout = 10 * time.Second
firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
handlerChunkWriteSize = 4 << 10
defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
// maxQueuedControlFrames is the maximum number of control frames like
// SETTINGS, PING and RST_STREAM that will be queued for writing before
// the connection is closed to prevent memory exhaustion attacks.
maxQueuedControlFrames = 10000
)
@ -127,6 +133,22 @@ type Server struct {
// If zero or negative, there is no timeout.
IdleTimeout time.Duration
// ReadIdleTimeout is the timeout after which a health check using a ping
// frame will be carried out if no frame is received on the connection.
// If zero, no health check is performed.
ReadIdleTimeout time.Duration
// PingTimeout is the timeout after which the connection will be closed
// if a response to a ping is not received.
// If zero, a default of 15 seconds is used.
PingTimeout time.Duration
// WriteByteTimeout is the timeout after which a connection will be
// closed if no data can be written to it. The timeout begins when data is
// available to write, and is extended whenever any bytes are written.
// If zero or negative, there is no timeout.
WriteByteTimeout time.Duration
// MaxUploadBufferPerConnection is the size of the initial flow
// control window for each connections. The HTTP/2 spec does not
// allow this to be smaller than 65535 or larger than 2^32-1.
@ -156,60 +178,13 @@ type Server struct {
state *serverInternalState
}
func (s *Server) initialConnRecvWindowSize() int32 {
if s.MaxUploadBufferPerConnection >= initialWindowSize {
return s.MaxUploadBufferPerConnection
}
return 1 << 20
}
func (s *Server) initialStreamRecvWindowSize() int32 {
if s.MaxUploadBufferPerStream > 0 {
return s.MaxUploadBufferPerStream
}
return 1 << 20
}
func (s *Server) maxReadFrameSize() uint32 {
if v := s.MaxReadFrameSize; v >= minMaxFrameSize && v <= maxFrameSize {
return v
}
return defaultMaxReadFrameSize
}
func (s *Server) maxConcurrentStreams() uint32 {
if v := s.MaxConcurrentStreams; v > 0 {
return v
}
return defaultMaxStreams
}
func (s *Server) maxDecoderHeaderTableSize() uint32 {
if v := s.MaxDecoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
func (s *Server) maxEncoderHeaderTableSize() uint32 {
if v := s.MaxEncoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
// maxQueuedControlFrames is the maximum number of control frames like
// SETTINGS, PING and RST_STREAM that will be queued for writing before
// the connection is closed to prevent memory exhaustion attacks.
func (s *Server) maxQueuedControlFrames() int {
// TODO: if anybody asks, add a Server field, and remember to define the
// behavior of negative values.
return maxQueuedControlFrames
}
type serverInternalState struct {
mu sync.Mutex
activeConns map[*serverConn]struct{}
// Pool of error channels. This is per-Server rather than global
// because channels can't be reused across synctest bubbles.
errChanPool sync.Pool
}
func (s *serverInternalState) registerConn(sc *serverConn) {
@ -241,6 +216,27 @@ func (s *serverInternalState) startGracefulShutdown() {
s.mu.Unlock()
}
// Global error channel pool used for uninitialized Servers.
// We use a per-Server pool when possible to avoid using channels across synctest bubbles.
var errChanPool = sync.Pool{
New: func() any { return make(chan error, 1) },
}
func (s *serverInternalState) getErrChan() chan error {
if s == nil {
return errChanPool.Get().(chan error) // Server used without calling ConfigureServer
}
return s.errChanPool.Get().(chan error)
}
func (s *serverInternalState) putErrChan(ch chan error) {
if s == nil {
errChanPool.Put(ch) // Server used without calling ConfigureServer
return
}
s.errChanPool.Put(ch)
}
// ConfigureServer adds HTTP/2 support to a net/http Server.
//
// The configuration conf may be nil.
@ -253,7 +249,10 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if conf == nil {
conf = new(Server)
}
conf.state = &serverInternalState{activeConns: make(map[*serverConn]struct{})}
conf.state = &serverInternalState{
activeConns: make(map[*serverConn]struct{}),
errChanPool: sync.Pool{New: func() any { return make(chan error, 1) }},
}
if h1, h2 := s, conf; h2.IdleTimeout == 0 {
if h1.IdleTimeout != 0 {
h2.IdleTimeout = h1.IdleTimeout
@ -303,7 +302,7 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if s.TLSNextProto == nil {
s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){}
}
protoHandler := func(hs *http.Server, c *tls.Conn, h http.Handler) {
protoHandler := func(hs *http.Server, c net.Conn, h http.Handler, sawClientPreface bool) {
if testHookOnConn != nil {
testHookOnConn()
}
@ -320,12 +319,31 @@ func ConfigureServer(s *http.Server, conf *Server) error {
ctx = bc.BaseContext()
}
conf.ServeConn(c, &ServeConnOpts{
Context: ctx,
Handler: h,
BaseConfig: hs,
Context: ctx,
Handler: h,
BaseConfig: hs,
SawClientPreface: sawClientPreface,
})
}
s.TLSNextProto[NextProtoTLS] = protoHandler
s.TLSNextProto[NextProtoTLS] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
protoHandler(hs, c, h, false)
}
// The "unencrypted_http2" TLSNextProto key is used to pass off non-TLS HTTP/2 conns.
//
// A connection passed in this method has already had the HTTP/2 preface read from it.
s.TLSNextProto[nextProtoUnencryptedHTTP2] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
nc, err := unencryptedNetConnFromTLSConn(c)
if err != nil {
if lg := hs.ErrorLog; lg != nil {
lg.Print(err)
} else {
log.Print(err)
}
go c.Close()
return
}
protoHandler(hs, nc, h, true)
}
return nil
}
@ -400,16 +418,25 @@ func (o *ServeConnOpts) handler() http.Handler {
//
// The opts parameter is optional. If nil, default values are used.
func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
if opts == nil {
opts = &ServeConnOpts{}
}
s.serveConn(c, opts, nil)
}
func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) {
baseCtx, cancel := serverConnBaseContext(c, opts)
defer cancel()
http1srv := opts.baseConfig()
conf := configFromServer(http1srv, s)
sc := &serverConn{
srv: s,
hs: opts.baseConfig(),
hs: http1srv,
conn: c,
baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(),
bw: newBufferedWriter(c),
bw: newBufferedWriter(c, conf.WriteByteTimeout),
handler: opts.handler(),
streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult),
@ -419,13 +446,19 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way
doneServing: make(chan struct{}),
clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value"
advMaxStreams: s.maxConcurrentStreams(),
advMaxStreams: conf.MaxConcurrentStreams,
initialStreamSendWindowSize: initialWindowSize,
initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
maxFrameSize: initialMaxFrameSize,
pingTimeout: conf.PingTimeout,
countErrorFunc: conf.CountError,
serveG: newGoroutineLock(),
pushEnabled: true,
sawClientPreface: opts.SawClientPreface,
}
if newf != nil {
newf(sc)
}
s.state.registerConn(sc)
defer s.state.unregisterConn(sc)
@ -439,10 +472,13 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
sc.conn.SetWriteDeadline(time.Time{})
}
if s.NewWriteScheduler != nil {
switch {
case s.NewWriteScheduler != nil:
sc.writeSched = s.NewWriteScheduler()
} else {
case clientPriorityDisabled(http1srv):
sc.writeSched = newRoundRobinWriteScheduler()
default:
sc.writeSched = newPriorityWriteSchedulerRFC9218()
}
// These start at the RFC-specified defaults. If there is a higher
@ -451,15 +487,15 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
sc.flow.add(initialWindowSize)
sc.inflow.init(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize())
sc.hpackEncoder.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize)
fr := NewFramer(sc.bw, c)
if s.CountError != nil {
fr.countError = s.CountError
if conf.CountError != nil {
fr.countError = conf.CountError
}
fr.ReadMetaHeaders = hpack.NewDecoder(s.maxDecoderHeaderTableSize(), nil)
fr.ReadMetaHeaders = hpack.NewDecoder(conf.MaxDecoderHeaderTableSize, nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize()
fr.SetMaxReadFrameSize(s.maxReadFrameSize())
fr.SetMaxReadFrameSize(conf.MaxReadFrameSize)
sc.framer = fr
if tc, ok := c.(connectionStater); ok {
@ -492,7 +528,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// So for now, do nothing here again.
}
if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
if !conf.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
// "Endpoints MAY choose to generate a connection error
// (Section 5.4.1) of type INADEQUATE_SECURITY if one of
// the prohibited cipher suites are negotiated."
@ -529,7 +565,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
opts.UpgradeRequest = nil
}
sc.serve()
sc.serve(conf)
}
func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) {
@ -569,6 +605,7 @@ type serverConn struct {
tlsState *tls.ConnectionState // shared by all handlers, like net/http
remoteAddrStr string
writeSched WriteScheduler
countErrorFunc func(errType string)
// Everything following is owned by the serve loop; use serveG.check():
serveG goroutineLock // used to verify funcs are on serve()
@ -588,6 +625,7 @@ type serverConn struct {
streams map[uint32]*stream
unstartedHandlers []unstartedHandler
initialStreamSendWindowSize int32
initialStreamRecvWindowSize int32
maxFrameSize int32
peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
@ -598,9 +636,14 @@ type serverConn struct {
inGoAway bool // we've started to or sent GOAWAY
inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
needToSendGoAway bool // we need to schedule a GOAWAY frame write
pingSent bool
sentPingData [8]byte
goAwayCode ErrCode
shutdownTimer *time.Timer // nil until used
idleTimer *time.Timer // nil if unused
readIdleTimeout time.Duration
pingTimeout time.Duration
readIdleTimer *time.Timer // nil if unused
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
@ -608,6 +651,23 @@ type serverConn struct {
// Used by startGracefulShutdown.
shutdownOnce sync.Once
// Used for RFC 9218 prioritization.
hasIntermediary bool // connection is done via an intermediary / proxy
priorityAware bool // the client has sent priority signal, meaning that it is aware of it.
}
func (sc *serverConn) writeSchedIgnoresRFC7540() bool {
switch sc.writeSched.(type) {
case *priorityWriteSchedulerRFC9218:
return true
case *randomWriteScheduler:
return true
case *roundRobinWriteScheduler:
return true
default:
return false
}
}
func (sc *serverConn) maxHeaderListSize() uint32 {
@ -615,11 +675,7 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
if n <= 0 {
n = http.DefaultMaxHeaderBytes
}
// http2's count is in a slightly different unit and includes 32 bytes per pair.
// So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
const perFieldOverhead = 32 // per http2 spec
const typicalHeaders = 10 // conservative
return uint32(n + typicalHeaders*perFieldOverhead)
return uint32(adjustHTTP1MaxHeaderSize(int64(n)))
}
func (sc *serverConn) curOpenStreams() uint32 {
@ -775,8 +831,7 @@ const maxCachedCanonicalHeadersKeysSize = 2048
func (sc *serverConn) canonicalHeader(v string) string {
sc.serveG.check()
buildCommonHeaderMapsOnce()
cv, ok := commonCanonHeader[v]
cv, ok := httpcommon.CachedCanonicalHeader(v)
if ok {
return cv
}
@ -811,8 +866,8 @@ type readFrameResult struct {
// consumer is done with the frame.
// It's run on its own goroutine.
func (sc *serverConn) readFrames() {
gate := make(gate)
gateDone := gate.Done
gate := make(chan struct{})
gateDone := func() { gate <- struct{}{} }
for {
f, err := sc.framer.ReadFrame()
select {
@ -881,7 +936,7 @@ func (sc *serverConn) notePanic() {
}
}
func (sc *serverConn) serve() {
func (sc *serverConn) serve(conf http2Config) {
sc.serveG.check()
defer sc.notePanic()
defer sc.conn.Close()
@ -893,20 +948,27 @@ func (sc *serverConn) serve() {
sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
}
settings := writeSettings{
{SettingMaxFrameSize, conf.MaxReadFrameSize},
{SettingMaxConcurrentStreams, sc.advMaxStreams},
{SettingMaxHeaderListSize, sc.maxHeaderListSize()},
{SettingHeaderTableSize, conf.MaxDecoderHeaderTableSize},
{SettingInitialWindowSize, uint32(sc.initialStreamRecvWindowSize)},
}
if !disableExtendedConnectProtocol {
settings = append(settings, Setting{SettingEnableConnectProtocol, 1})
}
if sc.writeSchedIgnoresRFC7540() {
settings = append(settings, Setting{SettingNoRFC7540Priorities, 1})
}
sc.writeFrame(FrameWriteRequest{
write: writeSettings{
{SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
{SettingMaxConcurrentStreams, sc.advMaxStreams},
{SettingMaxHeaderListSize, sc.maxHeaderListSize()},
{SettingHeaderTableSize, sc.srv.maxDecoderHeaderTableSize()},
{SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())},
},
write: settings,
})
sc.unackedSettings++
// Each connection starts with initialWindowSize inflow tokens.
// If a higher value is configured, we add more tokens.
if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 {
if diff := conf.MaxUploadBufferPerConnection - initialWindowSize; diff > 0 {
sc.sendWindowUpdate(nil, int(diff))
}
@ -926,11 +988,18 @@ func (sc *serverConn) serve() {
defer sc.idleTimer.Stop()
}
if conf.SendPingTimeout > 0 {
sc.readIdleTimeout = conf.SendPingTimeout
sc.readIdleTimer = time.AfterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
defer sc.readIdleTimer.Stop()
}
go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer)
defer settingsTimer.Stop()
lastFrameTime := time.Now()
loopNum := 0
for {
loopNum++
@ -944,6 +1013,7 @@ func (sc *serverConn) serve() {
case res := <-sc.wroteFrameCh:
sc.wroteFrame(res)
case res := <-sc.readFrameCh:
lastFrameTime = time.Now()
// Process any written frames before reading new frames from the client since a
// written frame could have triggered a new stream to be started.
if sc.writingFrameAsync {
@ -975,6 +1045,8 @@ func (sc *serverConn) serve() {
case idleTimerMsg:
sc.vlogf("connection is idle")
sc.goAway(ErrCodeNo)
case readIdleTimerMsg:
sc.handlePingTimer(lastFrameTime)
case shutdownTimerMsg:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return
@ -997,7 +1069,7 @@ func (sc *serverConn) serve() {
// If the peer is causing us to generate a lot of control frames,
// but not reading them from us, assume they are trying to make us
// run out of memory.
if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() {
if sc.queuedControlFrames > maxQueuedControlFrames {
sc.vlogf("http2: too many control frames in send queue, closing connection")
return
}
@ -1013,12 +1085,42 @@ func (sc *serverConn) serve() {
}
}
func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) {
if sc.pingSent {
sc.logf("timeout waiting for PING response")
if f := sc.countErrorFunc; f != nil {
f("conn_close_lost_ping")
}
sc.conn.Close()
return
}
pingAt := lastFrameReadTime.Add(sc.readIdleTimeout)
now := time.Now()
if pingAt.After(now) {
// We received frames since arming the ping timer.
// Reset it for the next possible timeout.
sc.readIdleTimer.Reset(pingAt.Sub(now))
return
}
sc.pingSent = true
// Ignore crypto/rand.Read errors: It generally can't fail, and worse case if it does
// is we send a PING frame containing 0s.
_, _ = rand.Read(sc.sentPingData[:])
sc.writeFrame(FrameWriteRequest{
write: &writePing{data: sc.sentPingData},
})
sc.readIdleTimer.Reset(sc.pingTimeout)
}
type serverMessage int
// Message values sent to serveMsgCh.
var (
settingsTimerMsg = new(serverMessage)
idleTimerMsg = new(serverMessage)
readIdleTimerMsg = new(serverMessage)
shutdownTimerMsg = new(serverMessage)
gracefulShutdownMsg = new(serverMessage)
handlerDoneMsg = new(serverMessage)
@ -1026,6 +1128,7 @@ var (
func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) }
func (sc *serverConn) onReadIdleTimer() { sc.sendServeMsg(readIdleTimerMsg) }
func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) }
func (sc *serverConn) sendServeMsg(msg interface{}) {
@ -1072,10 +1175,6 @@ func (sc *serverConn) readPreface() error {
}
}
var errChanPool = sync.Pool{
New: func() interface{} { return make(chan error, 1) },
}
var writeDataPool = sync.Pool{
New: func() interface{} { return new(writeData) },
}
@ -1083,7 +1182,7 @@ var writeDataPool = sync.Pool{
// writeDataFromHandler writes DATA response frames from a handler on
// the given stream.
func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error {
ch := errChanPool.Get().(chan error)
ch := sc.srv.state.getErrChan()
writeArg := writeDataPool.Get().(*writeData)
*writeArg = writeData{stream.id, data, endStream}
err := sc.writeFrameFromHandler(FrameWriteRequest{
@ -1115,7 +1214,7 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStrea
return errStreamClosed
}
}
errChanPool.Put(ch)
sc.srv.state.putErrChan(ch)
if frameWriteDone {
writeDataPool.Put(writeArg)
}
@ -1278,6 +1377,10 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
sc.writingFrame = false
sc.writingFrameAsync = false
if res.err != nil {
sc.conn.Close()
}
wr := res.wr
if writeEndsStream(wr.write) {
@ -1543,6 +1646,8 @@ func (sc *serverConn) processFrame(f Frame) error {
// A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE
// frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
return sc.countError("push_promise", ConnectionError(ErrCodeProtocol))
case *PriorityUpdateFrame:
return sc.processPriorityUpdate(f)
default:
sc.vlogf("http2: server ignoring frame: %v", f.Header())
return nil
@ -1552,6 +1657,11 @@ func (sc *serverConn) processFrame(f Frame) error {
func (sc *serverConn) processPing(f *PingFrame) error {
sc.serveG.check()
if f.IsAck() {
if sc.pingSent && sc.sentPingData == f.Data {
// This is a response to a PING we sent.
sc.pingSent = false
sc.readIdleTimer.Reset(sc.readIdleTimeout)
}
// 6.7 PING: " An endpoint MUST NOT respond to PING frames
// containing this flag."
return nil
@ -1639,7 +1749,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
delete(sc.streams, st.id)
if len(sc.streams) == 0 {
sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout > 0 {
if sc.srv.IdleTimeout > 0 && sc.idleTimer != nil {
sc.idleTimer.Reset(sc.srv.IdleTimeout)
}
if h1ServerKeepAlivesDisabled(sc.hs) {
@ -1661,6 +1771,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
}
}
st.closeErr = err
st.cancelCtx()
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.CloseStream(st.id)
}
@ -1714,6 +1825,13 @@ func (sc *serverConn) processSetting(s Setting) error {
sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31
case SettingMaxHeaderListSize:
sc.peerMaxHeaderListSize = s.Val
case SettingEnableConnectProtocol:
// Receipt of this parameter by a server does not
// have any impact
case SettingNoRFC7540Priorities:
if s.Val > 1 {
return ConnectionError(ErrCodeProtocol)
}
default:
// Unknown setting: "An endpoint that receives a SETTINGS
// frame with any unknown or unsupported identifier MUST
@ -1984,13 +2102,33 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
if f.StreamEnded() {
initialState = stateHalfClosedRemote
}
st := sc.newStream(id, 0, initialState)
// We are handling two special cases here:
// 1. When a request is sent via an intermediary, we force priority to be
// u=3,i. This is essentially a round-robin behavior, and is done to ensure
// fairness between, for example, multiple clients using the same proxy.
// 2. Until a client has shown that it is aware of RFC 9218, we make its
// streams non-incremental by default. This is done to preserve the
// historical behavior of handling streams in a round-robin manner, rather
// than one-by-one to completion.
initialPriority := defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary)
if _, ok := sc.writeSched.(*priorityWriteSchedulerRFC9218); ok && !sc.hasIntermediary {
headerPriority, priorityAware, hasIntermediary := f.rfc9218Priority(sc.priorityAware)
initialPriority = headerPriority
sc.hasIntermediary = hasIntermediary
if priorityAware {
sc.priorityAware = true
}
}
st := sc.newStream(id, 0, initialState, initialPriority)
if f.HasPriority() {
if err := sc.checkPriority(f.StreamID, f.Priority); err != nil {
return err
}
sc.writeSched.AdjustStream(st.id, f.Priority)
if !sc.writeSchedIgnoresRFC7540() {
sc.writeSched.AdjustStream(st.id, f.Priority)
}
}
rw, req, err := sc.newWriterAndRequest(st, f)
@ -2031,7 +2169,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) {
sc.serveG.check()
id := uint32(1)
sc.maxClientStreamID = id
st := sc.newStream(id, 0, stateHalfClosedRemote)
st := sc.newStream(id, 0, stateHalfClosedRemote, defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary))
st.reqTrailer = req.Trailer
if st.reqTrailer != nil {
st.trailer = make(http.Header)
@ -2096,11 +2234,32 @@ func (sc *serverConn) processPriority(f *PriorityFrame) error {
if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil {
return err
}
// We need to avoid calling AdjustStream when using the RFC 9218 write
// scheduler. Otherwise, incremental's zero value in PriorityParam will
// unexpectedly make all streams non-incremental. This causes us to process
// streams one-by-one to completion rather than doing it in a round-robin
// manner (the historical behavior), which might be unexpected to users.
if sc.writeSchedIgnoresRFC7540() {
return nil
}
sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam)
return nil
}
func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream {
func (sc *serverConn) processPriorityUpdate(f *PriorityUpdateFrame) error {
sc.priorityAware = true
if _, ok := sc.writeSched.(*priorityWriteSchedulerRFC9218); !ok {
return nil
}
p, ok := parseRFC9218Priority(f.Priority, sc.priorityAware)
if !ok {
return sc.countError("unparsable_priority_update", streamError(f.PrioritizedStreamID, ErrCodeProtocol))
}
sc.writeSched.AdjustStream(f.PrioritizedStreamID, p)
return nil
}
func (sc *serverConn) newStream(id, pusherID uint32, state streamState, priority PriorityParam) *stream {
sc.serveG.check()
if id == 0 {
panic("internal error: cannot create stream with id 0")
@ -2117,13 +2276,13 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.cw.Init()
st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize())
st.inflow.init(sc.initialStreamRecvWindowSize)
if sc.hs.WriteTimeout > 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}
sc.streams[id] = st
sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID})
sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID, priority: priority})
if st.isPushed() {
sc.curPushedStreams++
} else {
@ -2139,19 +2298,25 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) {
sc.serveG.check()
rp := requestParam{
method: f.PseudoValue("method"),
scheme: f.PseudoValue("scheme"),
authority: f.PseudoValue("authority"),
path: f.PseudoValue("path"),
rp := httpcommon.ServerRequestParam{
Method: f.PseudoValue("method"),
Scheme: f.PseudoValue("scheme"),
Authority: f.PseudoValue("authority"),
Path: f.PseudoValue("path"),
Protocol: f.PseudoValue("protocol"),
}
isConnect := rp.method == "CONNECT"
// extended connect is disabled, so we should not see :protocol
if disableExtendedConnectProtocol && rp.Protocol != "" {
return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
}
isConnect := rp.Method == "CONNECT"
if isConnect {
if rp.path != "" || rp.scheme != "" || rp.authority == "" {
if rp.Protocol == "" && (rp.Path != "" || rp.Scheme != "" || rp.Authority == "") {
return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
}
} else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
} else if rp.Method == "" || rp.Path == "" || (rp.Scheme != "https" && rp.Scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses:
//
// Malformed requests or responses that are detected
@ -2165,12 +2330,16 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol))
}
rp.header = make(http.Header)
header := make(http.Header)
rp.Header = header
for _, hf := range f.RegularFields() {
rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value)
header.Add(sc.canonicalHeader(hf.Name), hf.Value)
}
if rp.authority == "" {
rp.authority = rp.header.Get("Host")
if rp.Authority == "" {
rp.Authority = header.Get("Host")
}
if rp.Protocol != "" {
header.Set(":protocol", rp.Protocol)
}
rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
@ -2179,7 +2348,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
}
bodyOpen := !f.StreamEnded()
if bodyOpen {
if vv, ok := rp.header["Content-Length"]; ok {
if vv, ok := rp.Header["Content-Length"]; ok {
if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil {
req.ContentLength = int64(cl)
} else {
@ -2195,83 +2364,38 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
return rw, req, nil
}
type requestParam struct {
method string
scheme, authority, path string
header http.Header
}
func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) {
func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.ServerRequestParam) (*responseWriter, *http.Request, error) {
sc.serveG.check()
var tlsState *tls.ConnectionState // nil if not scheme https
if rp.scheme == "https" {
if rp.Scheme == "https" {
tlsState = sc.tlsState
}
needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue")
if needsContinue {
rp.header.Del("Expect")
}
// Merge Cookie headers into one "; "-delimited value.
if cookies := rp.header["Cookie"]; len(cookies) > 1 {
rp.header.Set("Cookie", strings.Join(cookies, "; "))
}
// Setup Trailers
var trailer http.Header
for _, v := range rp.header["Trailer"] {
for _, key := range strings.Split(v, ",") {
key = http.CanonicalHeaderKey(textproto.TrimString(key))
switch key {
case "Transfer-Encoding", "Trailer", "Content-Length":
// Bogus. (copy of http1 rules)
// Ignore.
default:
if trailer == nil {
trailer = make(http.Header)
}
trailer[key] = nil
}
}
}
delete(rp.header, "Trailer")
var url_ *url.URL
var requestURI string
if rp.method == "CONNECT" {
url_ = &url.URL{Host: rp.authority}
requestURI = rp.authority // mimic HTTP/1 server behavior
} else {
var err error
url_, err = url.ParseRequestURI(rp.path)
if err != nil {
return nil, nil, sc.countError("bad_path", streamError(st.id, ErrCodeProtocol))
}
requestURI = rp.path
res := httpcommon.NewServerRequest(rp)
if res.InvalidReason != "" {
return nil, nil, sc.countError(res.InvalidReason, streamError(st.id, ErrCodeProtocol))
}
body := &requestBody{
conn: sc,
stream: st,
needsContinue: needsContinue,
needsContinue: res.NeedsContinue,
}
req := &http.Request{
Method: rp.method,
URL: url_,
req := (&http.Request{
Method: rp.Method,
URL: res.URL,
RemoteAddr: sc.remoteAddrStr,
Header: rp.header,
RequestURI: requestURI,
Header: rp.Header,
RequestURI: res.RequestURI,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
TLS: tlsState,
Host: rp.authority,
Host: rp.Authority,
Body: body,
Trailer: trailer,
}
req = req.WithContext(st.ctx)
Trailer: res.Trailer,
}).WithContext(st.ctx)
rw := sc.newResponseWriter(st, req)
return rw, req, nil
}
@ -2391,7 +2515,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro
// waiting for this frame to be written, so an http.Flush mid-handler
// writes out the correct value of keys, before a handler later potentially
// mutates it.
errc = errChanPool.Get().(chan error)
errc = sc.srv.state.getErrChan()
}
if err := sc.writeFrameFromHandler(FrameWriteRequest{
write: headerData,
@ -2403,7 +2527,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro
if errc != nil {
select {
case err := <-errc:
errChanPool.Put(errc)
sc.srv.state.putErrChan(errc)
return err
case <-sc.doneServing:
return errClientDisconnected
@ -2510,7 +2634,7 @@ func (b *requestBody) Read(p []byte) (n int, err error) {
if err == io.EOF {
b.sawEOF = true
}
if b.conn == nil && inTests {
if b.conn == nil {
return
}
b.conn.noteBodyReadFromHandler(b.stream, n, err)
@ -2811,6 +2935,11 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
return nil
}
func (w *responseWriter) EnableFullDuplex() error {
// We always support full duplex responses, so this is a no-op.
return nil
}
func (w *responseWriter) Flush() {
w.FlushError()
}
@ -3079,7 +3208,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error {
method: opts.Method,
url: u,
header: cloneHeader(opts.Header),
done: errChanPool.Get().(chan error),
done: sc.srv.state.getErrChan(),
}
select {
@ -3096,7 +3225,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error {
case <-st.cw:
return errStreamClosed
case err := <-msg.done:
errChanPool.Put(msg.done)
sc.srv.state.putErrChan(msg.done)
return err
}
}
@ -3159,13 +3288,13 @@ func (sc *serverConn) startPush(msg *startPushRequest) {
// transition to "half closed (remote)" after sending the initial HEADERS, but
// we start in "half closed (remote)" for simplicity.
// See further comments at the definition of stateHalfClosedRemote.
promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote)
rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{
method: msg.method,
scheme: msg.url.Scheme,
authority: msg.url.Host,
path: msg.url.RequestURI(),
header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE
promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote, defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary))
rw, req, err := sc.newWriterAndRequestNoBody(promised, httpcommon.ServerRequestParam{
Method: msg.method,
Scheme: msg.url.Scheme,
Authority: msg.url.Host,
Path: msg.url.RequestURI(),
Header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE
})
if err != nil {
// Should not happen, since we've already validated msg.url.
@ -3257,7 +3386,7 @@ func (sc *serverConn) countError(name string, err error) error {
if sc == nil || sc.srv == nil {
return err
}
f := sc.srv.CountError
f := sc.countErrorFunc
if f == nil {
return err
}

View file

@ -1,331 +0,0 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"context"
"sync"
"time"
)
// testSyncHooks coordinates goroutines in tests.
//
// For example, a call to ClientConn.RoundTrip involves several goroutines, including:
// - the goroutine running RoundTrip;
// - the clientStream.doRequest goroutine, which writes the request; and
// - the clientStream.readLoop goroutine, which reads the response.
//
// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
// are blocked waiting for some condition such as reading the Request.Body or waiting for
// flow control to become available.
//
// The testSyncHooks also manage timers and synthetic time in tests.
// This permits us to, for example, start a request and cause it to time out waiting for
// response headers without resorting to time.Sleep calls.
type testSyncHooks struct {
// active/inactive act as a mutex and condition variable.
//
// - neither chan contains a value: testSyncHooks is locked.
// - active contains a value: unlocked, and at least one goroutine is not blocked
// - inactive contains a value: unlocked, and all goroutines are blocked
active chan struct{}
inactive chan struct{}
// goroutine counts
total int // total goroutines
condwait map[*sync.Cond]int // blocked in sync.Cond.Wait
blocked []*testBlockedGoroutine // otherwise blocked
// fake time
now time.Time
timers []*fakeTimer
// Transport testing: Report various events.
newclientconn func(*ClientConn)
newstream func(*clientStream)
}
// testBlockedGoroutine is a blocked goroutine.
type testBlockedGoroutine struct {
f func() bool // blocked until f returns true
ch chan struct{} // closed when unblocked
}
func newTestSyncHooks() *testSyncHooks {
h := &testSyncHooks{
active: make(chan struct{}, 1),
inactive: make(chan struct{}, 1),
condwait: map[*sync.Cond]int{},
}
h.inactive <- struct{}{}
h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
return h
}
// lock acquires the testSyncHooks mutex.
func (h *testSyncHooks) lock() {
select {
case <-h.active:
case <-h.inactive:
}
}
// waitInactive waits for all goroutines to become inactive.
func (h *testSyncHooks) waitInactive() {
for {
<-h.inactive
if !h.unlock() {
break
}
}
}
// unlock releases the testSyncHooks mutex.
// It reports whether any goroutines are active.
func (h *testSyncHooks) unlock() (active bool) {
// Look for a blocked goroutine which can be unblocked.
blocked := h.blocked[:0]
unblocked := false
for _, b := range h.blocked {
if !unblocked && b.f() {
unblocked = true
close(b.ch)
} else {
blocked = append(blocked, b)
}
}
h.blocked = blocked
// Count goroutines blocked on condition variables.
condwait := 0
for _, count := range h.condwait {
condwait += count
}
if h.total > condwait+len(blocked) {
h.active <- struct{}{}
return true
} else {
h.inactive <- struct{}{}
return false
}
}
// goRun starts a new goroutine.
func (h *testSyncHooks) goRun(f func()) {
h.lock()
h.total++
h.unlock()
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
f()
}()
}
// blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
// It waits until f returns true before proceeding.
//
// Example usage:
//
// h.blockUntil(func() bool {
// // Is the context done yet?
// select {
// case <-ctx.Done():
// default:
// return false
// }
// return true
// })
// // Wait for the context to become done.
// <-ctx.Done()
//
// The function f passed to blockUntil must be non-blocking and idempotent.
func (h *testSyncHooks) blockUntil(f func() bool) {
if f() {
return
}
ch := make(chan struct{})
h.lock()
h.blocked = append(h.blocked, &testBlockedGoroutine{
f: f,
ch: ch,
})
h.unlock()
<-ch
}
// broadcast is sync.Cond.Broadcast.
func (h *testSyncHooks) condBroadcast(cond *sync.Cond) {
h.lock()
delete(h.condwait, cond)
h.unlock()
cond.Broadcast()
}
// broadcast is sync.Cond.Wait.
func (h *testSyncHooks) condWait(cond *sync.Cond) {
h.lock()
h.condwait[cond]++
h.unlock()
}
// newTimer creates a new fake timer.
func (h *testSyncHooks) newTimer(d time.Duration) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
c: make(chan time.Time),
}
h.timers = append(h.timers, t)
return t
}
// afterFunc creates a new fake AfterFunc timer.
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
f: f,
}
h.timers = append(h.timers, t)
return t
}
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
t := h.afterFunc(d, cancel)
return ctx, func() {
t.Stop()
cancel()
}
}
func (h *testSyncHooks) timeUntilEvent() time.Duration {
h.lock()
defer h.unlock()
var next time.Time
for _, t := range h.timers {
if next.IsZero() || t.when.Before(next) {
next = t.when
}
}
if d := next.Sub(h.now); d > 0 {
return d
}
return 0
}
// advance advances time and causes synthetic timers to fire.
func (h *testSyncHooks) advance(d time.Duration) {
h.lock()
defer h.unlock()
h.now = h.now.Add(d)
timers := h.timers[:0]
for _, t := range h.timers {
t := t // remove after go.mod depends on go1.22
t.mu.Lock()
switch {
case t.when.After(h.now):
timers = append(timers, t)
case t.when.IsZero():
// stopped timer
default:
t.when = time.Time{}
if t.c != nil {
close(t.c)
}
if t.f != nil {
h.total++
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
t.f()
}()
}
}
t.mu.Unlock()
}
h.timers = timers
}
// A timer wraps a time.Timer, or a synthetic equivalent in tests.
// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
type timer interface {
C() <-chan time.Time
Stop() bool
Reset(d time.Duration) bool
}
// timeTimer implements timer using real time.
type timeTimer struct {
t *time.Timer
c chan time.Time
}
// newTimeTimer creates a new timer using real time.
func newTimeTimer(d time.Duration) timer {
ch := make(chan time.Time)
t := time.AfterFunc(d, func() {
close(ch)
})
return &timeTimer{t, ch}
}
// newTimeAfterFunc creates an AfterFunc timer using real time.
func newTimeAfterFunc(d time.Duration, f func()) timer {
return &timeTimer{
t: time.AfterFunc(d, f),
}
}
func (t timeTimer) C() <-chan time.Time { return t.c }
func (t timeTimer) Stop() bool { return t.t.Stop() }
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
// fakeTimer implements timer using fake time.
type fakeTimer struct {
hooks *testSyncHooks
mu sync.Mutex
when time.Time // when the timer will fire
c chan time.Time // closed when the timer fires; mutually exclusive with f
f func() // called when the timer fires; mutually exclusive with c
}
func (t *fakeTimer) C() <-chan time.Time { return t.c }
func (t *fakeTimer) Stop() bool {
t.mu.Lock()
defer t.mu.Unlock()
stopped := t.when.IsZero()
t.when = time.Time{}
return stopped
}
func (t *fakeTimer) Reset(d time.Duration) bool {
if t.c != nil || t.f == nil {
panic("fakeTimer only supports Reset on AfterFunc timers")
}
t.mu.Lock()
defer t.mu.Unlock()
t.hooks.lock()
defer t.hooks.unlock()
active := !t.when.IsZero()
t.when = t.hooks.now.Add(d)
if !active {
t.hooks.timers = append(t.hooks.timers, t)
}
return active
}

File diff suppressed because it is too large Load diff

32
vendor/golang.org/x/net/http2/unencrypted.go generated vendored Normal file
View file

@ -0,0 +1,32 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"crypto/tls"
"errors"
"net"
)
const nextProtoUnencryptedHTTP2 = "unencrypted_http2"
// unencryptedNetConnFromTLSConn retrieves a net.Conn wrapped in a *tls.Conn.
//
// TLSNextProto functions accept a *tls.Conn.
//
// When passing an unencrypted HTTP/2 connection to a TLSNextProto function,
// we pass a *tls.Conn with an underlying net.Conn containing the unencrypted connection.
// To be extra careful about mistakes (accidentally dropping TLS encryption in a place
// where we want it), the tls.Conn contains a net.Conn with an UnencryptedNetConn method
// that returns the actual connection we want to use.
func unencryptedNetConnFromTLSConn(tc *tls.Conn) (net.Conn, error) {
conner, ok := tc.NetConn().(interface {
UnencryptedNetConn() net.Conn
})
if !ok {
return nil, errors.New("http2: TLS conn unexpectedly found in unencrypted handoff")
}
return conner.UnencryptedNetConn(), nil
}

View file

@ -13,6 +13,7 @@ import (
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/internal/httpcommon"
)
// writeFramer is implemented by any type that is used to write frames.
@ -131,6 +132,16 @@ func (se StreamError) writeFrame(ctx writeContext) error {
func (se StreamError) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max }
type writePing struct {
data [8]byte
}
func (w writePing) writeFrame(ctx writeContext) error {
return ctx.Framer().WritePing(false, w.data)
}
func (w writePing) staysWithinBuffer(max int) bool { return frameHeaderLen+len(w.data) <= max }
type writePingAck struct{ pf *PingFrame }
func (w writePingAck) writeFrame(ctx writeContext) error {
@ -341,7 +352,7 @@ func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
}
for _, k := range keys {
vv := h[k]
k, ascii := lowerHeader(k)
k, ascii := httpcommon.LowerHeader(k)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).

View file

@ -42,6 +42,8 @@ type OpenStreamOptions struct {
// PusherID is zero if the stream was initiated by the client. Otherwise,
// PusherID names the stream that pushed the newly opened stream.
PusherID uint32
// priority is used to set the priority of the newly opened stream.
priority PriorityParam
}
// FrameWriteRequest is a request to write a frame.
@ -183,45 +185,75 @@ func (wr *FrameWriteRequest) replyToWriter(err error) {
}
// writeQueue is used by implementations of WriteScheduler.
//
// Each writeQueue contains a queue of FrameWriteRequests, meant to store all
// FrameWriteRequests associated with a given stream. This is implemented as a
// two-stage queue: currQueue[currPos:] and nextQueue. Removing an item is done
// by incrementing currPos of currQueue. Adding an item is done by appending it
// to the nextQueue. If currQueue is empty when trying to remove an item, we
// can swap currQueue and nextQueue to remedy the situation.
// This two-stage queue is analogous to the use of two lists in Okasaki's
// purely functional queue but without the overhead of reversing the list when
// swapping stages.
//
// writeQueue also contains prev and next, this can be used by implementations
// of WriteScheduler to construct data structures that represent the order of
// writing between different streams (e.g. circular linked list).
type writeQueue struct {
s []FrameWriteRequest
currQueue []FrameWriteRequest
nextQueue []FrameWriteRequest
currPos int
prev, next *writeQueue
}
func (q *writeQueue) empty() bool { return len(q.s) == 0 }
func (q *writeQueue) empty() bool {
return (len(q.currQueue) - q.currPos + len(q.nextQueue)) == 0
}
func (q *writeQueue) push(wr FrameWriteRequest) {
q.s = append(q.s, wr)
q.nextQueue = append(q.nextQueue, wr)
}
func (q *writeQueue) shift() FrameWriteRequest {
if len(q.s) == 0 {
if q.empty() {
panic("invalid use of queue")
}
wr := q.s[0]
// TODO: less copy-happy queue.
copy(q.s, q.s[1:])
q.s[len(q.s)-1] = FrameWriteRequest{}
q.s = q.s[:len(q.s)-1]
if q.currPos >= len(q.currQueue) {
q.currQueue, q.currPos, q.nextQueue = q.nextQueue, 0, q.currQueue[:0]
}
wr := q.currQueue[q.currPos]
q.currQueue[q.currPos] = FrameWriteRequest{}
q.currPos++
return wr
}
func (q *writeQueue) peek() *FrameWriteRequest {
if q.currPos < len(q.currQueue) {
return &q.currQueue[q.currPos]
}
if len(q.nextQueue) > 0 {
return &q.nextQueue[0]
}
return nil
}
// consume consumes up to n bytes from q.s[0]. If the frame is
// entirely consumed, it is removed from the queue. If the frame
// is partially consumed, the frame is kept with the consumed
// bytes removed. Returns true iff any bytes were consumed.
func (q *writeQueue) consume(n int32) (FrameWriteRequest, bool) {
if len(q.s) == 0 {
if q.empty() {
return FrameWriteRequest{}, false
}
consumed, rest, numresult := q.s[0].Consume(n)
consumed, rest, numresult := q.peek().Consume(n)
switch numresult {
case 0:
return FrameWriteRequest{}, false
case 1:
q.shift()
case 2:
q.s[0] = rest
*q.peek() = rest
}
return consumed, true
}
@ -230,10 +262,15 @@ type writeQueuePool []*writeQueue
// put inserts an unused writeQueue into the pool.
func (p *writeQueuePool) put(q *writeQueue) {
for i := range q.s {
q.s[i] = FrameWriteRequest{}
for i := range q.currQueue {
q.currQueue[i] = FrameWriteRequest{}
}
q.s = q.s[:0]
for i := range q.nextQueue {
q.nextQueue[i] = FrameWriteRequest{}
}
q.currQueue = q.currQueue[:0]
q.nextQueue = q.nextQueue[:0]
q.currPos = 0
*p = append(*p, q)
}

View file

@ -11,7 +11,7 @@ import (
)
// RFC 7540, Section 5.3.5: the default weight is 16.
const priorityDefaultWeight = 15 // 16 = 15 + 1
const priorityDefaultWeightRFC7540 = 15 // 16 = 15 + 1
// PriorityWriteSchedulerConfig configures a priorityWriteScheduler.
type PriorityWriteSchedulerConfig struct {
@ -56,6 +56,10 @@ type PriorityWriteSchedulerConfig struct {
// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3.
// If cfg is nil, default options are used.
func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler {
return newPriorityWriteSchedulerRFC7540(cfg)
}
func newPriorityWriteSchedulerRFC7540(cfg *PriorityWriteSchedulerConfig) WriteScheduler {
if cfg == nil {
// For justification of these defaults, see:
// https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY
@ -66,8 +70,8 @@ func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler
}
}
ws := &priorityWriteScheduler{
nodes: make(map[uint32]*priorityNode),
ws := &priorityWriteSchedulerRFC7540{
nodes: make(map[uint32]*priorityNodeRFC7540),
maxClosedNodesInTree: cfg.MaxClosedNodesInTree,
maxIdleNodesInTree: cfg.MaxIdleNodesInTree,
enableWriteThrottle: cfg.ThrottleOutOfOrderWrites,
@ -81,32 +85,32 @@ func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler
return ws
}
type priorityNodeState int
type priorityNodeStateRFC7540 int
const (
priorityNodeOpen priorityNodeState = iota
priorityNodeClosed
priorityNodeIdle
priorityNodeOpenRFC7540 priorityNodeStateRFC7540 = iota
priorityNodeClosedRFC7540
priorityNodeIdleRFC7540
)
// priorityNode is a node in an HTTP/2 priority tree.
// priorityNodeRFC7540 is a node in an HTTP/2 priority tree.
// Each node is associated with a single stream ID.
// See RFC 7540, Section 5.3.
type priorityNode struct {
q writeQueue // queue of pending frames to write
id uint32 // id of the stream, or 0 for the root of the tree
weight uint8 // the actual weight is weight+1, so the value is in [1,256]
state priorityNodeState // open | closed | idle
bytes int64 // number of bytes written by this node, or 0 if closed
subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree
type priorityNodeRFC7540 struct {
q writeQueue // queue of pending frames to write
id uint32 // id of the stream, or 0 for the root of the tree
weight uint8 // the actual weight is weight+1, so the value is in [1,256]
state priorityNodeStateRFC7540 // open | closed | idle
bytes int64 // number of bytes written by this node, or 0 if closed
subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree
// These links form the priority tree.
parent *priorityNode
kids *priorityNode // start of the kids list
prev, next *priorityNode // doubly-linked list of siblings
parent *priorityNodeRFC7540
kids *priorityNodeRFC7540 // start of the kids list
prev, next *priorityNodeRFC7540 // doubly-linked list of siblings
}
func (n *priorityNode) setParent(parent *priorityNode) {
func (n *priorityNodeRFC7540) setParent(parent *priorityNodeRFC7540) {
if n == parent {
panic("setParent to self")
}
@ -141,7 +145,7 @@ func (n *priorityNode) setParent(parent *priorityNode) {
}
}
func (n *priorityNode) addBytes(b int64) {
func (n *priorityNodeRFC7540) addBytes(b int64) {
n.bytes += b
for ; n != nil; n = n.parent {
n.subtreeBytes += b
@ -154,7 +158,7 @@ func (n *priorityNode) addBytes(b int64) {
//
// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true
// if any ancestor p of n is still open (ignoring the root node).
func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f func(*priorityNode, bool) bool) bool {
func (n *priorityNodeRFC7540) walkReadyInOrder(openParent bool, tmp *[]*priorityNodeRFC7540, f func(*priorityNodeRFC7540, bool) bool) bool {
if !n.q.empty() && f(n, openParent) {
return true
}
@ -165,7 +169,7 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f
// Don't consider the root "open" when updating openParent since
// we can't send data frames on the root stream (only control frames).
if n.id != 0 {
openParent = openParent || (n.state == priorityNodeOpen)
openParent = openParent || (n.state == priorityNodeOpenRFC7540)
}
// Common case: only one kid or all kids have the same weight.
@ -195,7 +199,7 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f
*tmp = append(*tmp, n.kids)
n.kids.setParent(nil)
}
sort.Sort(sortPriorityNodeSiblings(*tmp))
sort.Sort(sortPriorityNodeSiblingsRFC7540(*tmp))
for i := len(*tmp) - 1; i >= 0; i-- {
(*tmp)[i].setParent(n) // setParent inserts at the head of n.kids
}
@ -207,15 +211,15 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f
return false
}
type sortPriorityNodeSiblings []*priorityNode
type sortPriorityNodeSiblingsRFC7540 []*priorityNodeRFC7540
func (z sortPriorityNodeSiblings) Len() int { return len(z) }
func (z sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] }
func (z sortPriorityNodeSiblings) Less(i, k int) bool {
func (z sortPriorityNodeSiblingsRFC7540) Len() int { return len(z) }
func (z sortPriorityNodeSiblingsRFC7540) Swap(i, k int) { z[i], z[k] = z[k], z[i] }
func (z sortPriorityNodeSiblingsRFC7540) Less(i, k int) bool {
// Prefer the subtree that has sent fewer bytes relative to its weight.
// See sections 5.3.2 and 5.3.4.
wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes)
wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes)
wi, bi := float64(z[i].weight)+1, float64(z[i].subtreeBytes)
wk, bk := float64(z[k].weight)+1, float64(z[k].subtreeBytes)
if bi == 0 && bk == 0 {
return wi >= wk
}
@ -225,13 +229,13 @@ func (z sortPriorityNodeSiblings) Less(i, k int) bool {
return bi/bk <= wi/wk
}
type priorityWriteScheduler struct {
type priorityWriteSchedulerRFC7540 struct {
// root is the root of the priority tree, where root.id = 0.
// The root queues control frames that are not associated with any stream.
root priorityNode
root priorityNodeRFC7540
// nodes maps stream ids to priority tree nodes.
nodes map[uint32]*priorityNode
nodes map[uint32]*priorityNodeRFC7540
// maxID is the maximum stream id in nodes.
maxID uint32
@ -239,7 +243,7 @@ type priorityWriteScheduler struct {
// lists of nodes that have been closed or are idle, but are kept in
// the tree for improved prioritization. When the lengths exceed either
// maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded.
closedNodes, idleNodes []*priorityNode
closedNodes, idleNodes []*priorityNodeRFC7540
// From the config.
maxClosedNodesInTree int
@ -248,19 +252,19 @@ type priorityWriteScheduler struct {
enableWriteThrottle bool
// tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations.
tmp []*priorityNode
tmp []*priorityNodeRFC7540
// pool of empty queues for reuse.
queuePool writeQueuePool
}
func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) {
func (ws *priorityWriteSchedulerRFC7540) OpenStream(streamID uint32, options OpenStreamOptions) {
// The stream may be currently idle but cannot be opened or closed.
if curr := ws.nodes[streamID]; curr != nil {
if curr.state != priorityNodeIdle {
if curr.state != priorityNodeIdleRFC7540 {
panic(fmt.Sprintf("stream %d already opened", streamID))
}
curr.state = priorityNodeOpen
curr.state = priorityNodeOpenRFC7540
return
}
@ -272,11 +276,11 @@ func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStream
if parent == nil {
parent = &ws.root
}
n := &priorityNode{
n := &priorityNodeRFC7540{
q: *ws.queuePool.get(),
id: streamID,
weight: priorityDefaultWeight,
state: priorityNodeOpen,
weight: priorityDefaultWeightRFC7540,
state: priorityNodeOpenRFC7540,
}
n.setParent(parent)
ws.nodes[streamID] = n
@ -285,24 +289,23 @@ func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStream
}
}
func (ws *priorityWriteScheduler) CloseStream(streamID uint32) {
func (ws *priorityWriteSchedulerRFC7540) CloseStream(streamID uint32) {
if streamID == 0 {
panic("violation of WriteScheduler interface: cannot close stream 0")
}
if ws.nodes[streamID] == nil {
panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID))
}
if ws.nodes[streamID].state != priorityNodeOpen {
if ws.nodes[streamID].state != priorityNodeOpenRFC7540 {
panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID))
}
n := ws.nodes[streamID]
n.state = priorityNodeClosed
n.state = priorityNodeClosedRFC7540
n.addBytes(-n.bytes)
q := n.q
ws.queuePool.put(&q)
n.q.s = nil
if ws.maxClosedNodesInTree > 0 {
ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n)
} else {
@ -310,7 +313,7 @@ func (ws *priorityWriteScheduler) CloseStream(streamID uint32) {
}
}
func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) {
func (ws *priorityWriteSchedulerRFC7540) AdjustStream(streamID uint32, priority PriorityParam) {
if streamID == 0 {
panic("adjustPriority on root")
}
@ -324,11 +327,11 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit
return
}
ws.maxID = streamID
n = &priorityNode{
n = &priorityNodeRFC7540{
q: *ws.queuePool.get(),
id: streamID,
weight: priorityDefaultWeight,
state: priorityNodeIdle,
weight: priorityDefaultWeightRFC7540,
state: priorityNodeIdleRFC7540,
}
n.setParent(&ws.root)
ws.nodes[streamID] = n
@ -340,7 +343,7 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit
parent := ws.nodes[priority.StreamDep]
if parent == nil {
n.setParent(&ws.root)
n.weight = priorityDefaultWeight
n.weight = priorityDefaultWeightRFC7540
return
}
@ -381,8 +384,8 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit
n.weight = priority.Weight
}
func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) {
var n *priorityNode
func (ws *priorityWriteSchedulerRFC7540) Push(wr FrameWriteRequest) {
var n *priorityNodeRFC7540
if wr.isControl() {
n = &ws.root
} else {
@ -401,8 +404,8 @@ func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) {
n.q.push(wr)
}
func (ws *priorityWriteScheduler) Pop() (wr FrameWriteRequest, ok bool) {
ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNode, openParent bool) bool {
func (ws *priorityWriteSchedulerRFC7540) Pop() (wr FrameWriteRequest, ok bool) {
ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNodeRFC7540, openParent bool) bool {
limit := int32(math.MaxInt32)
if openParent {
limit = ws.writeThrottleLimit
@ -428,7 +431,7 @@ func (ws *priorityWriteScheduler) Pop() (wr FrameWriteRequest, ok bool) {
return wr, ok
}
func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, maxSize int, n *priorityNode) {
func (ws *priorityWriteSchedulerRFC7540) addClosedOrIdleNode(list *[]*priorityNodeRFC7540, maxSize int, n *priorityNodeRFC7540) {
if maxSize == 0 {
return
}
@ -442,9 +445,9 @@ func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, max
*list = append(*list, n)
}
func (ws *priorityWriteScheduler) removeNode(n *priorityNode) {
for k := n.kids; k != nil; k = k.next {
k.setParent(n.parent)
func (ws *priorityWriteSchedulerRFC7540) removeNode(n *priorityNodeRFC7540) {
for n.kids != nil {
n.kids.setParent(n.parent)
}
n.setParent(nil)
delete(ws.nodes, n.id)

View file

@ -0,0 +1,224 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"fmt"
"math"
)
type streamMetadata struct {
location *writeQueue
priority PriorityParam
}
type priorityWriteSchedulerRFC9218 struct {
// control contains control frames (SETTINGS, PING, etc.).
control writeQueue
// heads contain the head of a circular list of streams.
// We put these heads within a nested array that represents urgency and
// incremental, as defined in
// https://www.rfc-editor.org/rfc/rfc9218.html#name-priority-parameters.
// 8 represents u=0 up to u=7, and 2 represents i=false and i=true.
heads [8][2]*writeQueue
// streams contains a mapping between each stream ID and their metadata, so
// we can quickly locate them when needing to, for example, adjust their
// priority.
streams map[uint32]streamMetadata
// queuePool are empty queues for reuse.
queuePool writeQueuePool
// prioritizeIncremental is used to determine whether we should prioritize
// incremental streams or not, when urgency is the same in a given Pop()
// call.
prioritizeIncremental bool
// priorityUpdateBuf is used to buffer the most recent PRIORITY_UPDATE we
// receive per https://www.rfc-editor.org/rfc/rfc9218.html#name-the-priority_update-frame.
priorityUpdateBuf struct {
// streamID being 0 means that the buffer is empty. This is a safe
// assumption as PRIORITY_UPDATE for stream 0 is a PROTOCOL_ERROR.
streamID uint32
priority PriorityParam
}
}
func newPriorityWriteSchedulerRFC9218() WriteScheduler {
ws := &priorityWriteSchedulerRFC9218{
streams: make(map[uint32]streamMetadata),
}
return ws
}
func (ws *priorityWriteSchedulerRFC9218) OpenStream(streamID uint32, opt OpenStreamOptions) {
if ws.streams[streamID].location != nil {
panic(fmt.Errorf("stream %d already opened", streamID))
}
if streamID == ws.priorityUpdateBuf.streamID {
ws.priorityUpdateBuf.streamID = 0
opt.priority = ws.priorityUpdateBuf.priority
}
q := ws.queuePool.get()
ws.streams[streamID] = streamMetadata{
location: q,
priority: opt.priority,
}
u, i := opt.priority.urgency, opt.priority.incremental
if ws.heads[u][i] == nil {
ws.heads[u][i] = q
q.next = q
q.prev = q
} else {
// Queues are stored in a ring.
// Insert the new stream before ws.head, putting it at the end of the list.
q.prev = ws.heads[u][i].prev
q.next = ws.heads[u][i]
q.prev.next = q
q.next.prev = q
}
}
func (ws *priorityWriteSchedulerRFC9218) CloseStream(streamID uint32) {
metadata := ws.streams[streamID]
q, u, i := metadata.location, metadata.priority.urgency, metadata.priority.incremental
if q == nil {
return
}
if q.next == q {
// This was the only open stream.
ws.heads[u][i] = nil
} else {
q.prev.next = q.next
q.next.prev = q.prev
if ws.heads[u][i] == q {
ws.heads[u][i] = q.next
}
}
delete(ws.streams, streamID)
ws.queuePool.put(q)
}
func (ws *priorityWriteSchedulerRFC9218) AdjustStream(streamID uint32, priority PriorityParam) {
metadata := ws.streams[streamID]
q, u, i := metadata.location, metadata.priority.urgency, metadata.priority.incremental
if q == nil {
ws.priorityUpdateBuf.streamID = streamID
ws.priorityUpdateBuf.priority = priority
return
}
// Remove stream from current location.
if q.next == q {
// This was the only open stream.
ws.heads[u][i] = nil
} else {
q.prev.next = q.next
q.next.prev = q.prev
if ws.heads[u][i] == q {
ws.heads[u][i] = q.next
}
}
// Insert stream to the new queue.
u, i = priority.urgency, priority.incremental
if ws.heads[u][i] == nil {
ws.heads[u][i] = q
q.next = q
q.prev = q
} else {
// Queues are stored in a ring.
// Insert the new stream before ws.head, putting it at the end of the list.
q.prev = ws.heads[u][i].prev
q.next = ws.heads[u][i]
q.prev.next = q
q.next.prev = q
}
// Update the metadata.
ws.streams[streamID] = streamMetadata{
location: q,
priority: priority,
}
}
func (ws *priorityWriteSchedulerRFC9218) Push(wr FrameWriteRequest) {
if wr.isControl() {
ws.control.push(wr)
return
}
q := ws.streams[wr.StreamID()].location
if q == nil {
// This is a closed stream.
// wr should not be a HEADERS or DATA frame.
// We push the request onto the control queue.
if wr.DataSize() > 0 {
panic("add DATA on non-open stream")
}
ws.control.push(wr)
return
}
q.push(wr)
}
func (ws *priorityWriteSchedulerRFC9218) Pop() (FrameWriteRequest, bool) {
// Control and RST_STREAM frames first.
if !ws.control.empty() {
return ws.control.shift(), true
}
// On the next Pop(), we want to prioritize incremental if we prioritized
// non-incremental request of the same urgency this time. Vice-versa.
// i.e. when there are incremental and non-incremental requests at the same
// priority, we give 50% of our bandwidth to the incremental ones in
// aggregate and 50% to the first non-incremental one (since
// non-incremental streams do not use round-robin writes).
ws.prioritizeIncremental = !ws.prioritizeIncremental
// Always prioritize lowest u (i.e. highest urgency level).
for u := range ws.heads {
for i := range ws.heads[u] {
// When we want to prioritize incremental, we try to pop i=true
// first before i=false when u is the same.
if ws.prioritizeIncremental {
i = (i + 1) % 2
}
q := ws.heads[u][i]
if q == nil {
continue
}
for {
if wr, ok := q.consume(math.MaxInt32); ok {
if i == 1 {
// For incremental streams, we update head to q.next so
// we can round-robin between multiple streams that can
// immediately benefit from partial writes.
ws.heads[u][i] = q.next
} else {
// For non-incremental streams, we try to finish one to
// completion rather than doing round-robin. However,
// we update head here so that if q.consume() is !ok
// (e.g. the stream has no more frame to consume), head
// is updated to the next q that has frames to consume
// on future iterations. This way, we do not prioritize
// writing to unavailable stream on next Pop() calls,
// preventing head-of-line blocking.
ws.heads[u][i] = q
}
return wr, true
}
q = q.next
if q == ws.heads[u][i] {
break
}
}
}
}
return FrameWriteRequest{}, false
}

View file

@ -25,7 +25,7 @@ type roundRobinWriteScheduler struct {
}
// newRoundRobinWriteScheduler constructs a new write scheduler.
// The round robin scheduler priorizes control frames
// The round robin scheduler prioritizes control frames
// like SETTINGS and PING over DATA frames.
// When there are no control frames to send, it performs a round-robin
// selection from the ready streams.

53
vendor/golang.org/x/net/internal/httpcommon/ascii.go generated vendored Normal file
View file

@ -0,0 +1,53 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httpcommon
import "strings"
// The HTTP protocols are defined in terms of ASCII, not Unicode. This file
// contains helper functions which may use Unicode-aware functions which would
// otherwise be unsafe and could introduce vulnerabilities if used improperly.
// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
// are equal, ASCII-case-insensitively.
func asciiEqualFold(s, t string) bool {
if len(s) != len(t) {
return false
}
for i := 0; i < len(s); i++ {
if lower(s[i]) != lower(t[i]) {
return false
}
}
return true
}
// lower returns the ASCII lowercase version of b.
func lower(b byte) byte {
if 'A' <= b && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
// isASCIIPrint returns whether s is ASCII and printable according to
// https://tools.ietf.org/html/rfc20#section-4.2.
func isASCIIPrint(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > '~' {
return false
}
}
return true
}
// asciiToLower returns the lowercase version of s if s is ASCII and printable,
// and whether or not it was.
func asciiToLower(s string) (lower string, ok bool) {
if !isASCIIPrint(s) {
return "", false
}
return strings.ToLower(s), true
}

View file

@ -1,11 +1,11 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
package httpcommon
import (
"net/http"
"net/textproto"
"sync"
)
@ -82,13 +82,15 @@ func buildCommonHeaderMaps() {
commonLowerHeader = make(map[string]string, len(common))
commonCanonHeader = make(map[string]string, len(common))
for _, v := range common {
chk := http.CanonicalHeaderKey(v)
chk := textproto.CanonicalMIMEHeaderKey(v)
commonLowerHeader[chk] = v
commonCanonHeader[v] = chk
}
}
func lowerHeader(v string) (lower string, ascii bool) {
// LowerHeader returns the lowercase form of a header name,
// used on the wire for HTTP/2 and HTTP/3 requests.
func LowerHeader(v string) (lower string, ascii bool) {
buildCommonHeaderMapsOnce()
if s, ok := commonLowerHeader[v]; ok {
return s, true
@ -96,10 +98,18 @@ func lowerHeader(v string) (lower string, ascii bool) {
return asciiToLower(v)
}
func canonicalHeader(v string) string {
// CanonicalHeader canonicalizes a header name. (For example, "host" becomes "Host".)
func CanonicalHeader(v string) string {
buildCommonHeaderMapsOnce()
if s, ok := commonCanonHeader[v]; ok {
return s
}
return http.CanonicalHeaderKey(v)
return textproto.CanonicalMIMEHeaderKey(v)
}
// CachedCanonicalHeader returns the canonical form of a well-known header name.
func CachedCanonicalHeader(v string) (string, bool) {
buildCommonHeaderMapsOnce()
s, ok := commonCanonHeader[v]
return s, ok
}

467
vendor/golang.org/x/net/internal/httpcommon/request.go generated vendored Normal file
View file

@ -0,0 +1,467 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httpcommon
import (
"context"
"errors"
"fmt"
"net/http/httptrace"
"net/textproto"
"net/url"
"sort"
"strconv"
"strings"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
)
var (
ErrRequestHeaderListSize = errors.New("request header list larger than peer's advertised limit")
)
// Request is a subset of http.Request.
// It'd be simpler to pass an *http.Request, of course, but we can't depend on net/http
// without creating a dependency cycle.
type Request struct {
URL *url.URL
Method string
Host string
Header map[string][]string
Trailer map[string][]string
ActualContentLength int64 // 0 means 0, -1 means unknown
}
// EncodeHeadersParam is parameters to EncodeHeaders.
type EncodeHeadersParam struct {
Request Request
// AddGzipHeader indicates that an "accept-encoding: gzip" header should be
// added to the request.
AddGzipHeader bool
// PeerMaxHeaderListSize, when non-zero, is the peer's MAX_HEADER_LIST_SIZE setting.
PeerMaxHeaderListSize uint64
// DefaultUserAgent is the User-Agent header to send when the request
// neither contains a User-Agent nor disables it.
DefaultUserAgent string
}
// EncodeHeadersResult is the result of EncodeHeaders.
type EncodeHeadersResult struct {
HasBody bool
HasTrailers bool
}
// EncodeHeaders constructs request headers common to HTTP/2 and HTTP/3.
// It validates a request and calls headerf with each pseudo-header and header
// for the request.
// The headerf function is called with the validated, canonicalized header name.
func EncodeHeaders(ctx context.Context, param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) {
req := param.Request
// Check for invalid connection-level headers.
if err := checkConnHeaders(req.Header); err != nil {
return res, err
}
if req.URL == nil {
return res, errors.New("Request.URL is nil")
}
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return res, err
}
if !httpguts.ValidHostHeader(host) {
return res, errors.New("invalid Host header")
}
// isNormalConnect is true if this is a non-extended CONNECT request.
isNormalConnect := false
var protocol string
if vv := req.Header[":protocol"]; len(vv) > 0 {
protocol = vv[0]
}
if req.Method == "CONNECT" && protocol == "" {
isNormalConnect = true
} else if protocol != "" && req.Method != "CONNECT" {
return res, errors.New("invalid :protocol header in non-CONNECT request")
}
// Validate the path, except for non-extended CONNECT requests which have no path.
var path string
if !isNormalConnect {
path = req.URL.RequestURI()
if !validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return res, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
} else {
return res, fmt.Errorf("invalid request :path %q", orig)
}
}
}
}
// Check for any invalid headers+trailers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
if err := validateHeaders(req.Header); err != "" {
return res, fmt.Errorf("invalid HTTP header %s", err)
}
if err := validateHeaders(req.Trailer); err != "" {
return res, fmt.Errorf("invalid HTTP trailer %s", err)
}
trailers, err := commaSeparatedTrailers(req.Trailer)
if err != nil {
return res, err
}
enumerateHeaders := func(f func(name, value string)) {
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production, see Sections 3.3 and 3.4 of
// [RFC3986]).
f(":authority", host)
m := req.Method
if m == "" {
m = "GET"
}
f(":method", m)
if !isNormalConnect {
f(":path", path)
f(":scheme", req.URL.Scheme)
}
if protocol != "" {
f(":protocol", protocol)
}
if trailers != "" {
f("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
if asciiEqualFold(k, "host") || asciiEqualFold(k, "content-length") {
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
} else if asciiEqualFold(k, "connection") ||
asciiEqualFold(k, "proxy-connection") ||
asciiEqualFold(k, "transfer-encoding") ||
asciiEqualFold(k, "upgrade") ||
asciiEqualFold(k, "keep-alive") {
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
} else if asciiEqualFold(k, "user-agent") {
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
} else if asciiEqualFold(k, "cookie") {
// Per 8.1.2.5 To allow for better compression efficiency, the
// Cookie header field MAY be split into separate header fields,
// each with one or more cookie-pairs.
for _, v := range vv {
for {
p := strings.IndexByte(v, ';')
if p < 0 {
break
}
f("cookie", v[:p])
p++
// strip space after semicolon if any.
for p+1 <= len(v) && v[p] == ' ' {
p++
}
v = v[p:]
}
if len(v) > 0 {
f("cookie", v)
}
}
continue
} else if k == ":protocol" {
// :protocol pseudo-header was already sent above.
continue
}
for _, v := range vv {
f(k, v)
}
}
if shouldSendReqContentLength(req.Method, req.ActualContentLength) {
f("content-length", strconv.FormatInt(req.ActualContentLength, 10))
}
if param.AddGzipHeader {
f("accept-encoding", "gzip")
}
if !didUA {
f("user-agent", param.DefaultUserAgent)
}
}
// Do a first pass over the headers counting bytes to ensure
// we don't exceed cc.peerMaxHeaderListSize. This is done as a
// separate pass before encoding the headers to prevent
// modifying the hpack state.
if param.PeerMaxHeaderListSize > 0 {
hlSize := uint64(0)
enumerateHeaders(func(name, value string) {
hf := hpack.HeaderField{Name: name, Value: value}
hlSize += uint64(hf.Size())
})
if hlSize > param.PeerMaxHeaderListSize {
return res, ErrRequestHeaderListSize
}
}
trace := httptrace.ContextClientTrace(ctx)
// Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) {
name, ascii := LowerHeader(name)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).
return
}
headerf(name, value)
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(name, []string{value})
}
})
res.HasBody = req.ActualContentLength != 0
res.HasTrailers = trailers != ""
return res, nil
}
// IsRequestGzip reports whether we should add an Accept-Encoding: gzip header
// for a request.
func IsRequestGzip(method string, header map[string][]string, disableCompression bool) bool {
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !disableCompression &&
len(header["Accept-Encoding"]) == 0 &&
len(header["Range"]) == 0 &&
method != "HEAD" {
// Request gzip only, not deflate. Deflate is ambiguous and
// not as universally supported anyway.
// See: https://zlib.net/zlib_faq.html#faq39
//
// Note that we don't request this for HEAD requests,
// due to a bug in nginx:
// http://trac.nginx.org/nginx/ticket/358
// https://golang.org/issue/5522
//
// We don't request gzip if the request is for a range, since
// auto-decoding a portion of a gzipped document will just fail
// anyway. See https://golang.org/issue/8923
return true
}
return false
}
// checkConnHeaders checks whether req has any invalid connection-level headers.
//
// https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2-3
// https://www.rfc-editor.org/rfc/rfc9113.html#section-8.2.2-1
//
// Certain headers are special-cased as okay but not transmitted later.
// For example, we allow "Transfer-Encoding: chunked", but drop the header when encoding.
func checkConnHeaders(h map[string][]string) error {
if vv := h["Upgrade"]; len(vv) > 0 && (vv[0] != "" && vv[0] != "chunked") {
return fmt.Errorf("invalid Upgrade request header: %q", vv)
}
if vv := h["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
return fmt.Errorf("invalid Transfer-Encoding request header: %q", vv)
}
if vv := h["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) {
return fmt.Errorf("invalid Connection request header: %q", vv)
}
return nil
}
func commaSeparatedTrailers(trailer map[string][]string) (string, error) {
keys := make([]string, 0, len(trailer))
for k := range trailer {
k = CanonicalHeader(k)
switch k {
case "Transfer-Encoding", "Trailer", "Content-Length":
return "", fmt.Errorf("invalid Trailer key %q", k)
}
keys = append(keys, k)
}
if len(keys) > 0 {
sort.Strings(keys)
return strings.Join(keys, ","), nil
}
return "", nil
}
// validPseudoPath reports whether v is a valid :path pseudo-header
// value. It must be either:
//
// - a non-empty string starting with '/'
// - the string '*', for OPTIONS requests.
//
// For now this is only used a quick check for deciding when to clean
// up Opaque URLs before sending requests from the Transport.
// See golang.org/issue/16847
//
// We used to enforce that the path also didn't start with "//", but
// Google's GFE accepts such paths and Chrome sends them, so ignore
// that part of the spec. See golang.org/issue/19103.
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/') || v == "*"
}
func validateHeaders(hdrs map[string][]string) string {
for k, vv := range hdrs {
if !httpguts.ValidHeaderFieldName(k) && k != ":protocol" {
return fmt.Sprintf("name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error,
// because it may be sensitive.
return fmt.Sprintf("value for header %q", k)
}
}
}
return ""
}
// shouldSendReqContentLength reports whether we should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
// ServerRequestParam is parameters to NewServerRequest.
type ServerRequestParam struct {
Method string
Scheme, Authority, Path string
Protocol string
Header map[string][]string
}
// ServerRequestResult is the result of NewServerRequest.
type ServerRequestResult struct {
// Various http.Request fields.
URL *url.URL
RequestURI string
Trailer map[string][]string
NeedsContinue bool // client provided an "Expect: 100-continue" header
// If the request should be rejected, this is a short string suitable for passing
// to the http2 package's CountError function.
// It might be a bit odd to return errors this way rather than returning an error,
// but this ensures we don't forget to include a CountError reason.
InvalidReason string
}
func NewServerRequest(rp ServerRequestParam) ServerRequestResult {
needsContinue := httpguts.HeaderValuesContainsToken(rp.Header["Expect"], "100-continue")
if needsContinue {
delete(rp.Header, "Expect")
}
// Merge Cookie headers into one "; "-delimited value.
if cookies := rp.Header["Cookie"]; len(cookies) > 1 {
rp.Header["Cookie"] = []string{strings.Join(cookies, "; ")}
}
// Setup Trailers
var trailer map[string][]string
for _, v := range rp.Header["Trailer"] {
for _, key := range strings.Split(v, ",") {
key = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(key))
switch key {
case "Transfer-Encoding", "Trailer", "Content-Length":
// Bogus. (copy of http1 rules)
// Ignore.
default:
if trailer == nil {
trailer = make(map[string][]string)
}
trailer[key] = nil
}
}
}
delete(rp.Header, "Trailer")
// "':authority' MUST NOT include the deprecated userinfo subcomponent
// for "http" or "https" schemed URIs."
// https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8
if strings.IndexByte(rp.Authority, '@') != -1 && (rp.Scheme == "http" || rp.Scheme == "https") {
return ServerRequestResult{
InvalidReason: "userinfo_in_authority",
}
}
var url_ *url.URL
var requestURI string
if rp.Method == "CONNECT" && rp.Protocol == "" {
url_ = &url.URL{Host: rp.Authority}
requestURI = rp.Authority // mimic HTTP/1 server behavior
} else {
var err error
url_, err = url.ParseRequestURI(rp.Path)
if err != nil {
return ServerRequestResult{
InvalidReason: "bad_path",
}
}
requestURI = rp.Path
}
return ServerRequestResult{
URL: url_,
NeedsContinue: needsContinue,
RequestURI: requestURI,
Trailer: trailer,
}
}

665
vendor/golang.org/x/net/internal/httpsfv/httpsfv.go generated vendored Normal file
View file

@ -0,0 +1,665 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package httpsfv provides functionality for dealing with HTTP Structured
// Field Values.
package httpsfv
import (
"slices"
"strconv"
"strings"
"time"
"unicode/utf8"
)
func isLCAlpha(b byte) bool {
return (b >= 'a' && b <= 'z')
}
func isAlpha(b byte) bool {
return isLCAlpha(b) || (b >= 'A' && b <= 'Z')
}
func isDigit(b byte) bool {
return b >= '0' && b <= '9'
}
func isVChar(b byte) bool {
return b >= 0x21 && b <= 0x7e
}
func isSP(b byte) bool {
return b == 0x20
}
func isTChar(b byte) bool {
if isAlpha(b) || isDigit(b) {
return true
}
return slices.Contains([]byte{'!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~'}, b)
}
func countLeftWhitespace(s string) int {
i := 0
for _, ch := range []byte(s) {
if ch != ' ' && ch != '\t' {
break
}
i++
}
return i
}
// https://www.rfc-editor.org/rfc/rfc4648#section-8.
func decOctetHex(ch1, ch2 byte) (ch byte, ok bool) {
decBase16 := func(in byte) (out byte, ok bool) {
if !isDigit(in) && !(in >= 'a' && in <= 'f') {
return 0, false
}
if isDigit(in) {
return in - '0', true
}
return in - 'a' + 10, true
}
if ch1, ok = decBase16(ch1); !ok {
return 0, ok
}
if ch2, ok = decBase16(ch2); !ok {
return 0, ok
}
return ch1<<4 | ch2, true
}
// ParseList parses a list from a given HTTP Structured Field Values.
//
// Given an HTTP SFV string that represents a list, it will call the given
// function using each of the members and parameters contained in the list.
// This allows the caller to extract information out of the list.
//
// This function will return once it encounters the end of the string, or
// something that is not a list. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-list.
func ParseList(s string, f func(member, param string)) (ok bool) {
for len(s) != 0 {
var member, param string
if len(s) != 0 && s[0] == '(' {
if member, s, ok = consumeBareInnerList(s, nil); !ok {
return ok
}
} else {
if member, s, ok = consumeBareItem(s); !ok {
return ok
}
}
if param, s, ok = consumeParameter(s, nil); !ok {
return ok
}
if f != nil {
f(member, param)
}
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
break
}
if s[0] != ',' {
return false
}
s = s[1:]
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
return false
}
}
return true
}
// consumeBareInnerList consumes an inner list
// (https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-inner-list),
// except for the inner list's top-most parameter.
// For example, given `(a;b c;d);e`, it will consume only `(a;b c;d)`.
func consumeBareInnerList(s string, f func(bareItem, param string)) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != '(' {
return "", s, false
}
rest = s[1:]
for len(rest) != 0 {
var bareItem, param string
rest = rest[countLeftWhitespace(rest):]
if len(rest) != 0 && rest[0] == ')' {
rest = rest[1:]
break
}
if bareItem, rest, ok = consumeBareItem(rest); !ok {
return "", s, ok
}
if param, rest, ok = consumeParameter(rest, nil); !ok {
return "", s, ok
}
if len(rest) == 0 || (rest[0] != ')' && !isSP(rest[0])) {
return "", s, false
}
if f != nil {
f(bareItem, param)
}
}
return s[:len(s)-len(rest)], rest, true
}
// ParseBareInnerList parses a bare inner list from a given HTTP Structured
// Field Values.
//
// We define a bare inner list as an inner list
// (https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-inner-list),
// without the top-most parameter of the inner list. For example, given the
// inner list `(a;b c;d);e`, the bare inner list would be `(a;b c;d)`.
//
// Given an HTTP SFV string that represents a bare inner list, it will call the
// given function using each of the bare item and parameter within the bare
// inner list. This allows the caller to extract information out of the bare
// inner list.
//
// This function will return once it encounters the end of the bare inner list,
// or something that is not a bare inner list. If it cannot consume the entire
// given string, the ok value returned will be false.
func ParseBareInnerList(s string, f func(bareItem, param string)) (ok bool) {
_, rest, ok := consumeBareInnerList(s, f)
return rest == "" && ok
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-item.
func consumeItem(s string, f func(bareItem, param string)) (consumed, rest string, ok bool) {
var bareItem, param string
if bareItem, rest, ok = consumeBareItem(s); !ok {
return "", s, ok
}
if param, rest, ok = consumeParameter(rest, nil); !ok {
return "", s, ok
}
if f != nil {
f(bareItem, param)
}
return s[:len(s)-len(rest)], rest, true
}
// ParseItem parses an item from a given HTTP Structured Field Values.
//
// Given an HTTP SFV string that represents an item, it will call the given
// function once, with the bare item and the parameter of the item. This allows
// the caller to extract information out of the item.
//
// This function will return once it encounters the end of the string, or
// something that is not an item. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-item.
func ParseItem(s string, f func(bareItem, param string)) (ok bool) {
_, rest, ok := consumeItem(s, f)
return rest == "" && ok
}
// ParseDictionary parses a dictionary from a given HTTP Structured Field
// Values.
//
// Given an HTTP SFV string that represents a dictionary, it will call the
// given function using each of the keys, values, and parameters contained in
// the dictionary. This allows the caller to extract information out of the
// dictionary.
//
// This function will return once it encounters the end of the string, or
// something that is not a dictionary. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-dictionary.
func ParseDictionary(s string, f func(key, val, param string)) (ok bool) {
for len(s) != 0 {
var key, val, param string
val = "?1" // Default value for empty val is boolean true.
if key, s, ok = consumeKey(s); !ok {
return ok
}
if len(s) != 0 && s[0] == '=' {
s = s[1:]
if len(s) != 0 && s[0] == '(' {
if val, s, ok = consumeBareInnerList(s, nil); !ok {
return ok
}
} else {
if val, s, ok = consumeBareItem(s); !ok {
return ok
}
}
}
if param, s, ok = consumeParameter(s, nil); !ok {
return ok
}
if f != nil {
f(key, val, param)
}
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
break
}
if s[0] == ',' {
s = s[1:]
}
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
return false
}
}
return true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#parse-param.
func consumeParameter(s string, f func(key, val string)) (consumed, rest string, ok bool) {
rest = s
for len(rest) != 0 {
var key, val string
val = "?1" // Default value for empty val is boolean true.
if rest[0] != ';' {
break
}
rest = rest[1:]
rest = rest[countLeftWhitespace(rest):]
key, rest, ok = consumeKey(rest)
if !ok {
return "", s, ok
}
if len(rest) != 0 && rest[0] == '=' {
rest = rest[1:]
val, rest, ok = consumeBareItem(rest)
if !ok {
return "", s, ok
}
}
if f != nil {
f(key, val)
}
}
return s[:len(s)-len(rest)], rest, true
}
// ParseParameter parses a parameter from a given HTTP Structured Field Values.
//
// Given an HTTP SFV string that represents a parameter, it will call the given
// function using each of the keys and values contained in the parameter. This
// allows the caller to extract information out of the parameter.
//
// This function will return once it encounters the end of the string, or
// something that is not a parameter. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#parse-param.
func ParseParameter(s string, f func(key, val string)) (ok bool) {
_, rest, ok := consumeParameter(s, f)
return rest == "" && ok
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-key.
func consumeKey(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || (!isLCAlpha(s[0]) && s[0] != '*') {
return "", s, false
}
i := 0
for _, ch := range []byte(s) {
if !isLCAlpha(ch) && !isDigit(ch) && !slices.Contains([]byte("_-.*"), ch) {
break
}
i++
}
return s[:i], s[i:], true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim.
func consumeIntegerOrDecimal(s string) (consumed, rest string, ok bool) {
var i, signOffset, periodIndex int
var isDecimal bool
if i < len(s) && s[i] == '-' {
i++
signOffset++
}
if i >= len(s) {
return "", s, false
}
if !isDigit(s[i]) {
return "", s, false
}
for i < len(s) {
ch := s[i]
if isDigit(ch) {
i++
continue
}
if !isDecimal && ch == '.' {
if i-signOffset > 12 {
return "", s, false
}
periodIndex = i
isDecimal = true
i++
continue
}
break
}
if !isDecimal && i-signOffset > 15 {
return "", s, false
}
if isDecimal {
if i-signOffset > 16 {
return "", s, false
}
if s[i-1] == '.' {
return "", s, false
}
if i-periodIndex-1 > 3 {
return "", s, false
}
}
return s[:i], s[i:], true
}
// ParseInteger parses an integer from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid integer. It returns the
// parsed integer and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim.
func ParseInteger(s string) (parsed int64, ok bool) {
if _, rest, ok := consumeIntegerOrDecimal(s); !ok || rest != "" {
return 0, false
}
if n, err := strconv.ParseInt(s, 10, 64); err == nil {
return n, true
}
return 0, false
}
// ParseDecimal parses a decimal from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid decimal. It returns the
// parsed decimal and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim.
func ParseDecimal(s string) (parsed float64, ok bool) {
if _, rest, ok := consumeIntegerOrDecimal(s); !ok || rest != "" {
return 0, false
}
if !strings.Contains(s, ".") {
return 0, false
}
if n, err := strconv.ParseFloat(s, 64); err == nil {
return n, true
}
return 0, false
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-string.
func consumeString(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != '"' {
return "", s, false
}
for i := 1; i < len(s); i++ {
switch ch := s[i]; ch {
case '\\':
if i+1 >= len(s) {
return "", s, false
}
i++
if ch = s[i]; ch != '"' && ch != '\\' {
return "", s, false
}
case '"':
return s[:i+1], s[i+1:], true
default:
if !isVChar(ch) && !isSP(ch) {
return "", s, false
}
}
}
return "", s, false
}
// ParseString parses a Go string from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid string. It returns the
// parsed string and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-string.
func ParseString(s string) (parsed string, ok bool) {
if _, rest, ok := consumeString(s); !ok || rest != "" {
return "", false
}
return s[1 : len(s)-1], true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-token
func consumeToken(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || (!isAlpha(s[0]) && s[0] != '*') {
return "", s, false
}
i := 0
for _, ch := range []byte(s) {
if !isTChar(ch) && !slices.Contains([]byte(":/"), ch) {
break
}
i++
}
return s[:i], s[i:], true
}
// ParseToken parses a token from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid token. It returns the
// parsed token and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-token
func ParseToken(s string) (parsed string, ok bool) {
if _, rest, ok := consumeToken(s); !ok || rest != "" {
return "", false
}
return s, true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-byte-sequence.
func consumeByteSequence(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != ':' {
return "", s, false
}
for i := 1; i < len(s); i++ {
if ch := s[i]; ch == ':' {
return s[:i+1], s[i+1:], true
}
if ch := s[i]; !isAlpha(ch) && !isDigit(ch) && !slices.Contains([]byte("+/="), ch) {
return "", s, false
}
}
return "", s, false
}
// ParseByteSequence parses a byte sequence from a given HTTP Structured Field
// Values.
//
// The entire HTTP SFV string must consist of a valid byte sequence. It returns
// the parsed byte sequence and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-byte-sequence.
func ParseByteSequence(s string) (parsed []byte, ok bool) {
if _, rest, ok := consumeByteSequence(s); !ok || rest != "" {
return nil, false
}
return []byte(s[1 : len(s)-1]), true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-boolean.
func consumeBoolean(s string) (consumed, rest string, ok bool) {
if len(s) >= 2 && (s[:2] == "?0" || s[:2] == "?1") {
return s[:2], s[2:], true
}
return "", s, false
}
// ParseBoolean parses a boolean from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid boolean. It returns the
// parsed boolean and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-boolean.
func ParseBoolean(s string) (parsed bool, ok bool) {
if _, rest, ok := consumeBoolean(s); !ok || rest != "" {
return false, false
}
return s == "?1", true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-date.
func consumeDate(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != '@' {
return "", s, false
}
if _, rest, ok = consumeIntegerOrDecimal(s[1:]); !ok {
return "", s, ok
}
consumed = s[:len(s)-len(rest)]
if slices.Contains([]byte(consumed), '.') {
return "", s, false
}
return consumed, rest, ok
}
// ParseDate parses a date from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid date. It returns the
// parsed date and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-date.
func ParseDate(s string) (parsed time.Time, ok bool) {
if _, rest, ok := consumeDate(s); !ok || rest != "" {
return time.Time{}, false
}
if n, ok := ParseInteger(s[1:]); !ok {
return time.Time{}, false
} else {
return time.Unix(n, 0), true
}
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-display-string.
func consumeDisplayString(s string) (consumed, rest string, ok bool) {
// To prevent excessive allocation, especially when input is large, we
// maintain a buffer of 4 bytes to keep track of the last rune we
// encounter. This way, we can validate that the display string conforms to
// UTF-8 without actually building the whole string.
var lastRune [4]byte
var runeLen int
isPartOfValidRune := func(ch byte) bool {
lastRune[runeLen] = ch
runeLen++
if utf8.FullRune(lastRune[:runeLen]) {
r, s := utf8.DecodeRune(lastRune[:runeLen])
if r == utf8.RuneError {
return false
}
copy(lastRune[:], lastRune[s:runeLen])
runeLen -= s
return true
}
return runeLen <= 4
}
if len(s) <= 1 || s[:2] != `%"` {
return "", s, false
}
i := 2
for i < len(s) {
ch := s[i]
if !isVChar(ch) && !isSP(ch) {
return "", s, false
}
switch ch {
case '"':
if runeLen > 0 {
return "", s, false
}
return s[:i+1], s[i+1:], true
case '%':
if i+2 >= len(s) {
return "", s, false
}
if ch, ok = decOctetHex(s[i+1], s[i+2]); !ok {
return "", s, ok
}
if ok = isPartOfValidRune(ch); !ok {
return "", s, ok
}
i += 3
default:
if ok = isPartOfValidRune(ch); !ok {
return "", s, ok
}
i++
}
}
return "", s, false
}
// ParseDisplayString parses a display string from a given HTTP Structured
// Field Values.
//
// The entire HTTP SFV string must consist of a valid display string. It
// returns the parsed display string and an ok boolean value, indicating
// success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-display-string.
func ParseDisplayString(s string) (parsed string, ok bool) {
if _, rest, ok := consumeDisplayString(s); !ok || rest != "" {
return "", false
}
// consumeDisplayString() already validates that we have a valid display
// string. Therefore, we can just construct the display string, without
// validating it again.
s = s[2 : len(s)-1]
var b strings.Builder
for i := 0; i < len(s); {
if s[i] == '%' {
decoded, _ := decOctetHex(s[i+1], s[i+2])
b.WriteByte(decoded)
i += 3
continue
}
b.WriteByte(s[i])
i++
}
return b.String(), true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#parse-bare-item.
func consumeBareItem(s string) (consumed, rest string, ok bool) {
if len(s) == 0 {
return "", s, false
}
ch := s[0]
switch {
case ch == '-' || isDigit(ch):
return consumeIntegerOrDecimal(s)
case ch == '"':
return consumeString(s)
case ch == '*' || isAlpha(ch):
return consumeToken(s)
case ch == ':':
return consumeByteSequence(s)
case ch == '?':
return consumeBoolean(s)
case ch == '@':
return consumeDate(s)
case ch == '%':
return consumeDisplayString(s)
default:
return "", s, false
}
}

View file

@ -6,7 +6,10 @@
package socket
import "unsafe"
import (
"encoding/binary"
"unsafe"
)
func (h *msghdr) pack(vs []iovec, bs [][]byte, oob []byte, sa []byte) {
for i := range vs {
@ -31,5 +34,5 @@ func (h *msghdr) controllen() int {
}
func (h *msghdr) flags() int {
return int(NativeEndian.Uint32(h.Pad_cgo_2[:]))
return int(binary.NativeEndian.Uint32(h.Pad_cgo_2[:]))
}

View file

@ -7,6 +7,7 @@
package socket // import "golang.org/x/net/internal/socket"
import (
"encoding/binary"
"errors"
"net"
"runtime"
@ -58,7 +59,7 @@ func (o *Option) GetInt(c *Conn) (int, error) {
if o.Len == 1 {
return int(b[0]), nil
}
return int(NativeEndian.Uint32(b[:4])), nil
return int(binary.NativeEndian.Uint32(b[:4])), nil
}
// Set writes the option and value to the kernel.
@ -84,7 +85,7 @@ func (o *Option) SetInt(c *Conn, v int) error {
b = []byte{byte(v)}
} else {
var bb [4]byte
NativeEndian.PutUint32(bb[:o.Len], uint32(v))
binary.NativeEndian.PutUint32(bb[:o.Len], uint32(v))
b = bb[:4]
}
return o.set(c, b)

View file

@ -1,23 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package socket
import (
"encoding/binary"
"unsafe"
)
// NativeEndian is the machine native endian implementation of ByteOrder.
var NativeEndian binary.ByteOrder
func init() {
i := uint32(1)
b := (*[4]byte)(unsafe.Pointer(&i))
if b[0] == 1 {
NativeEndian = binary.LittleEndian
} else {
NativeEndian = binary.BigEndian
}
}

View file

@ -36,7 +36,7 @@ func marshalSockaddr(ip net.IP, port int, zone string, b []byte) int {
if ip4 := ip.To4(); ip4 != nil {
switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows":
NativeEndian.PutUint16(b[:2], uint16(sysAF_INET))
binary.NativeEndian.PutUint16(b[:2], uint16(sysAF_INET))
default:
b[0] = sizeofSockaddrInet4
b[1] = sysAF_INET
@ -48,7 +48,7 @@ func marshalSockaddr(ip net.IP, port int, zone string, b []byte) int {
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows":
NativeEndian.PutUint16(b[:2], uint16(sysAF_INET6))
binary.NativeEndian.PutUint16(b[:2], uint16(sysAF_INET6))
default:
b[0] = sizeofSockaddrInet6
b[1] = sysAF_INET6
@ -56,7 +56,7 @@ func marshalSockaddr(ip net.IP, port int, zone string, b []byte) int {
binary.BigEndian.PutUint16(b[2:4], uint16(port))
copy(b[8:24], ip6)
if zone != "" {
NativeEndian.PutUint32(b[24:28], uint32(zoneCache.index(zone)))
binary.NativeEndian.PutUint32(b[24:28], uint32(zoneCache.index(zone)))
}
return sizeofSockaddrInet6
}
@ -70,7 +70,7 @@ func parseInetAddr(b []byte, network string) (net.Addr, error) {
var af int
switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows":
af = int(NativeEndian.Uint16(b[:2]))
af = int(binary.NativeEndian.Uint16(b[:2]))
default:
af = int(b[1])
}
@ -89,7 +89,7 @@ func parseInetAddr(b []byte, network string) (net.Addr, error) {
}
ip = make(net.IP, net.IPv6len)
copy(ip, b[8:24])
if id := int(NativeEndian.Uint32(b[24:28])); id > 0 {
if id := int(binary.NativeEndian.Uint32(b[24:28])); id > 0 {
zone = zoneCache.name(id)
}
}

View file

@ -4,27 +4,27 @@
package socket
type iovec struct {
Base *byte
Len uint64
Base *byte
Len uint64
}
type msghdr struct {
Name *byte
Namelen uint32
Iov *iovec
Iovlen uint32
Control *byte
Controllen uint32
Flags int32
Name *byte
Namelen uint32
Iov *iovec
Iovlen uint32
Control *byte
Controllen uint32
Flags int32
}
type cmsghdr struct {
Len uint32
Level int32
Type int32
Len uint32
Level int32
Type int32
}
const (
sizeofIovec = 0x10
sizeofMsghdr = 0x30
sizeofIovec = 0x10
sizeofMsghdr = 0x30
)

View file

@ -4,27 +4,27 @@
package socket
type iovec struct {
Base *byte
Len uint64
Base *byte
Len uint64
}
type msghdr struct {
Name *byte
Namelen uint32
Iov *iovec
Iovlen uint32
Control *byte
Controllen uint32
Flags int32
Name *byte
Namelen uint32
Iov *iovec
Iovlen uint32
Control *byte
Controllen uint32
Flags int32
}
type cmsghdr struct {
Len uint32
Level int32
Type int32
Len uint32
Level int32
Type int32
}
const (
sizeofIovec = 0x10
sizeofMsghdr = 0x30
sizeofIovec = 0x10
sizeofMsghdr = 0x30
)

View file

@ -297,7 +297,7 @@ func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter,
b = append(b, up.Username...)
b = append(b, byte(len(up.Password)))
b = append(b, up.Password...)
// TODO(mikio): handle IO deadlines and cancelation if
// TODO(mikio): handle IO deadlines and cancellation if
// necessary
if _, err := rw.Write(b); err != nil {
return err

View file

@ -9,8 +9,6 @@ import (
"fmt"
"net"
"runtime"
"golang.org/x/net/internal/socket"
)
const (
@ -68,12 +66,12 @@ func (h *Header) Marshal() ([]byte, error) {
flagsAndFragOff := (h.FragOff & 0x1fff) | int(h.Flags<<13)
switch runtime.GOOS {
case "darwin", "ios", "dragonfly", "netbsd":
socket.NativeEndian.PutUint16(b[2:4], uint16(h.TotalLen))
socket.NativeEndian.PutUint16(b[6:8], uint16(flagsAndFragOff))
binary.NativeEndian.PutUint16(b[2:4], uint16(h.TotalLen))
binary.NativeEndian.PutUint16(b[6:8], uint16(flagsAndFragOff))
case "freebsd":
if freebsdVersion < 1100000 {
socket.NativeEndian.PutUint16(b[2:4], uint16(h.TotalLen))
socket.NativeEndian.PutUint16(b[6:8], uint16(flagsAndFragOff))
binary.NativeEndian.PutUint16(b[2:4], uint16(h.TotalLen))
binary.NativeEndian.PutUint16(b[6:8], uint16(flagsAndFragOff))
} else {
binary.BigEndian.PutUint16(b[2:4], uint16(h.TotalLen))
binary.BigEndian.PutUint16(b[6:8], uint16(flagsAndFragOff))
@ -127,15 +125,15 @@ func (h *Header) Parse(b []byte) error {
h.Dst = net.IPv4(b[16], b[17], b[18], b[19])
switch runtime.GOOS {
case "darwin", "ios", "dragonfly", "netbsd":
h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4])) + hdrlen
h.FragOff = int(socket.NativeEndian.Uint16(b[6:8]))
h.TotalLen = int(binary.NativeEndian.Uint16(b[2:4])) + hdrlen
h.FragOff = int(binary.NativeEndian.Uint16(b[6:8]))
case "freebsd":
if freebsdVersion < 1100000 {
h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4]))
h.TotalLen = int(binary.NativeEndian.Uint16(b[2:4]))
if freebsdVersion < 1000000 {
h.TotalLen += hdrlen
}
h.FragOff = int(socket.NativeEndian.Uint16(b[6:8]))
h.FragOff = int(binary.NativeEndian.Uint16(b[6:8]))
} else {
h.TotalLen = int(binary.BigEndian.Uint16(b[2:4]))
h.FragOff = int(binary.BigEndian.Uint16(b[6:8]))

Some files were not shown because too many files have changed in this diff Show more