chore(deps): bump Go 1.21→1.24 and resync vendor for Pion WebRTC v4 compat

Pion webrtc/v4 (v4.2.11) requires Go 1.24+. Upstream datarhei was at
go 1.21.0. Bumping to go 1.24.0 pulls minor bumps across testify,
golang.org/x/{crypto,net,sync,sys,text,time,tools,mod}; vendor/ is
regenerated via 'go mod vendor' to reflect the new versions.

No application code changes; pure dep bump to unblock M1.
This commit is contained in:
Zac Gaetano 2026-04-17 08:43:31 -04:00
parent 262a393b8d
commit 651a9a3eb5
385 changed files with 55239 additions and 79639 deletions

22
go.mod
View file

@ -1,8 +1,6 @@
module github.com/datarhei/core/v16 module github.com/datarhei/core/v16
go 1.21.0 go 1.24.0
toolchain go1.22.1
require ( require (
github.com/99designs/gqlgen v0.17.47 github.com/99designs/gqlgen v0.17.47
@ -26,13 +24,13 @@ require (
github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_golang v1.19.1
github.com/puzpuzpuz/xsync/v3 v3.1.0 github.com/puzpuzpuz/xsync/v3 v3.1.0
github.com/shirou/gopsutil/v3 v3.24.4 github.com/shirou/gopsutil/v3 v3.24.4
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.11.1
github.com/swaggo/echo-swagger v1.4.1 github.com/swaggo/echo-swagger v1.4.1
github.com/swaggo/swag v1.16.3 github.com/swaggo/swag v1.16.3
github.com/vektah/gqlparser/v2 v2.5.12 github.com/vektah/gqlparser/v2 v2.5.12
github.com/xeipuuv/gojsonschema v1.2.0 github.com/xeipuuv/gojsonschema v1.2.0
go.uber.org/zap v1.27.0 go.uber.org/zap v1.27.0
golang.org/x/mod v0.17.0 golang.org/x/mod v0.32.0
) )
require ( require (
@ -94,13 +92,13 @@ require (
github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect github.com/zeebo/blake3 v0.2.3 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.23.0 // indirect golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.25.0 // indirect golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.7.0 // indirect golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.20.0 // indirect golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.15.0 // indirect golang.org/x/text v0.34.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.10.0 // indirect
golang.org/x/tools v0.21.0 // indirect golang.org/x/tools v0.41.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect

35
go.sum
View file

@ -177,8 +177,9 @@ github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/swaggo/echo-swagger v1.4.1 h1:Yf0uPaJWp1uRtDloZALyLnvdBeoEL5Kc7DtnjzO/TUk= github.com/swaggo/echo-swagger v1.4.1 h1:Yf0uPaJWp1uRtDloZALyLnvdBeoEL5Kc7DtnjzO/TUk=
github.com/swaggo/echo-swagger v1.4.1/go.mod h1:C8bSi+9yH2FLZsnhqMZLIZddpUxZdBYuNHbtaS1Hljc= github.com/swaggo/echo-swagger v1.4.1/go.mod h1:C8bSi+9yH2FLZsnhqMZLIZddpUxZdBYuNHbtaS1Hljc=
github.com/swaggo/files/v2 v2.0.0 h1:hmAt8Dkynw7Ssz46F6pn8ok6YmGZqHSVLZ+HQM7i0kw= github.com/swaggo/files/v2 v2.0.0 h1:hmAt8Dkynw7Ssz46F6pn8ok6YmGZqHSVLZ+HQM7i0kw=
@ -222,14 +223,14 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -239,14 +240,14 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=

View file

@ -1,35 +1,35 @@
Developer Certificate of Origin Developer Certificate of Origin
Version 1.1 Version 1.1
Copyright (C) 2015- Klaus Post & Contributors. Copyright (C) 2015- Klaus Post & Contributors.
Email: klauspost@gmail.com Email: klauspost@gmail.com
Everyone is permitted to copy and distribute verbatim copies of this Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed. license document, but changing it is not allowed.
Developer's Certificate of Origin 1.1 Developer's Certificate of Origin 1.1
By making a contribution to this project, I certify that: By making a contribution to this project, I certify that:
(a) The contribution was created in whole or in part by me and I (a) The contribution was created in whole or in part by me and I
have the right to submit it under the open source license have the right to submit it under the open source license
indicated in the file; or indicated in the file; or
(b) The contribution is based upon previous work that, to the best (b) The contribution is based upon previous work that, to the best
of my knowledge, is covered under an appropriate open source of my knowledge, is covered under an appropriate open source
license and I have the right under that license to submit that license and I have the right under that license to submit that
work with modifications, whether created in whole or in part work with modifications, whether created in whole or in part
by me, under the same open source license (unless I am by me, under the same open source license (unless I am
permitted to submit under a different license), as indicated permitted to submit under a different license), as indicated
in the file; or in the file; or
(c) The contribution was provided directly to me by some other (c) The contribution was provided directly to me by some other
person who certified (a), (b) or (c) and I have not modified person who certified (a), (b) or (c) and I have not modified
it. it.
(d) I understand and agree that this project and the contribution (d) I understand and agree that this project and the contribution
are public and that a record of the contribution (including all are public and that a record of the contribution (including all
personal information I submit with it, including my sign-off) is personal information I submit with it, including my sign-off) is
maintained indefinitely and may be redistributed consistent with maintained indefinitely and may be redistributed consistent with
this project or the open source license(s) involved. this project or the open source license(s) involved.

View file

@ -7,10 +7,13 @@ import (
"time" "time"
) )
type CompareType int // Deprecated: CompareType has only ever been for internal use and has accidentally been published since v1.6.0. Do not use it.
type CompareType = compareResult
type compareResult int
const ( const (
compareLess CompareType = iota - 1 compareLess compareResult = iota - 1
compareEqual compareEqual
compareGreater compareGreater
) )
@ -39,7 +42,7 @@ var (
bytesType = reflect.TypeOf([]byte{}) bytesType = reflect.TypeOf([]byte{})
) )
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { func compare(obj1, obj2 interface{}, kind reflect.Kind) (compareResult, bool) {
obj1Value := reflect.ValueOf(obj1) obj1Value := reflect.ValueOf(obj1)
obj2Value := reflect.ValueOf(obj2) obj2Value := reflect.ValueOf(obj2)
@ -325,7 +328,13 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time) timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
} }
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64) if timeObj1.Before(timeObj2) {
return compareLess, true
}
if timeObj1.Equal(timeObj2) {
return compareEqual, true
}
return compareGreater, true
} }
case reflect.Slice: case reflect.Slice:
{ {
@ -345,7 +354,7 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte) bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
} }
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true return compareResult(bytes.Compare(bytesObj1, bytesObj2)), true
} }
case reflect.Uintptr: case reflect.Uintptr:
{ {
@ -381,7 +390,8 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) failMessage := fmt.Sprintf("\"%v\" is not greater than \"%v\"", e1, e2)
return compareTwoValues(t, e1, e2, []compareResult{compareGreater}, failMessage, msgAndArgs...)
} }
// GreaterOrEqual asserts that the first element is greater than or equal to the second // GreaterOrEqual asserts that the first element is greater than or equal to the second
@ -394,7 +404,8 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) failMessage := fmt.Sprintf("\"%v\" is not greater than or equal to \"%v\"", e1, e2)
return compareTwoValues(t, e1, e2, []compareResult{compareGreater, compareEqual}, failMessage, msgAndArgs...)
} }
// Less asserts that the first element is less than the second // Less asserts that the first element is less than the second
@ -406,7 +417,8 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{})
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) failMessage := fmt.Sprintf("\"%v\" is not less than \"%v\"", e1, e2)
return compareTwoValues(t, e1, e2, []compareResult{compareLess}, failMessage, msgAndArgs...)
} }
// LessOrEqual asserts that the first element is less than or equal to the second // LessOrEqual asserts that the first element is less than or equal to the second
@ -419,7 +431,8 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) failMessage := fmt.Sprintf("\"%v\" is not less than or equal to \"%v\"", e1, e2)
return compareTwoValues(t, e1, e2, []compareResult{compareLess, compareEqual}, failMessage, msgAndArgs...)
} }
// Positive asserts that the specified element is positive // Positive asserts that the specified element is positive
@ -431,7 +444,8 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
h.Helper() h.Helper()
} }
zero := reflect.Zero(reflect.TypeOf(e)) zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...) failMessage := fmt.Sprintf("\"%v\" is not positive", e)
return compareTwoValues(t, e, zero.Interface(), []compareResult{compareGreater}, failMessage, msgAndArgs...)
} }
// Negative asserts that the specified element is negative // Negative asserts that the specified element is negative
@ -443,10 +457,11 @@ func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
h.Helper() h.Helper()
} }
zero := reflect.Zero(reflect.TypeOf(e)) zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...) failMessage := fmt.Sprintf("\"%v\" is not negative", e)
return compareTwoValues(t, e, zero.Interface(), []compareResult{compareLess}, failMessage, msgAndArgs...)
} }
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool { func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
@ -459,17 +474,17 @@ func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedCompare
compareResult, isComparable := compare(e1, e2, e1Kind) compareResult, isComparable := compare(e1, e2, e1Kind)
if !isComparable { if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...) return Fail(t, fmt.Sprintf(`Can not compare type "%T"`, e1), msgAndArgs...)
} }
if !containsValue(allowedComparesResults, compareResult) { if !containsValue(allowedComparesResults, compareResult) {
return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...) return Fail(t, failMessage, msgAndArgs...)
} }
return true return true
} }
func containsValue(values []CompareType, value CompareType) bool { func containsValue(values []compareResult, value compareResult) bool {
for _, v := range values { for _, v := range values {
if v == value { if v == value {
return true return true

View file

@ -50,10 +50,19 @@ func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string
return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...) return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...)
} }
// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either // Emptyf asserts that the given value is "empty".
// a slice or a channel with len == 0. //
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
// //
// assert.Emptyf(t, obj, "error message %s", "formatted") // assert.Emptyf(t, obj, "error message %s", "formatted")
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool { func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -104,8 +113,8 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{},
return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...) return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// EqualValuesf asserts that two objects are equal or convertible to the same types // EqualValuesf asserts that two objects are equal or convertible to the larger
// and equal. // type and equal.
// //
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") // assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
@ -117,10 +126,8 @@ func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg stri
// Errorf asserts that a function returned an error (i.e. not `nil`). // Errorf asserts that a function returned an error (i.e. not `nil`).
// //
// actualObj, err := SomeFunction() // actualObj, err := SomeFunction()
// if assert.Errorf(t, err, "error message %s", "formatted") { // assert.Errorf(t, err, "error message %s", "formatted")
// assert.Equal(t, expectedErrorf, err)
// }
func Errorf(t TestingT, err error, msg string, args ...interface{}) bool { func Errorf(t TestingT, err error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -186,7 +193,7 @@ func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick
// assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") { // assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") {
// // add assertions as needed; any assertion failure will fail the current tick // // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true") // assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") // }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func EventuallyWithTf(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { func EventuallyWithTf(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -438,7 +445,19 @@ func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interf
return IsNonIncreasing(t, object, append([]interface{}{msg}, args...)...) return IsNonIncreasing(t, object, append([]interface{}{msg}, args...)...)
} }
// IsNotTypef asserts that the specified objects are not of the same type.
//
// assert.IsNotTypef(t, &NotMyStruct{}, &MyStruct{}, "error message %s", "formatted")
func IsNotTypef(t TestingT, theType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return IsNotType(t, theType, object, append([]interface{}{msg}, args...)...)
}
// IsTypef asserts that the specified objects are of the same type. // IsTypef asserts that the specified objects are of the same type.
//
// assert.IsTypef(t, &MyStruct{}, &MyStruct{}, "error message %s", "formatted")
func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool { func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -568,8 +587,24 @@ func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, a
return NotContains(t, s, contains, append([]interface{}{msg}, args...)...) return NotContains(t, s, contains, append([]interface{}{msg}, args...)...)
} }
// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified
// a slice or a channel with len == 0. // listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// assert.NotElementsMatchf(t, [1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false
//
// assert.NotElementsMatchf(t, [1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true
//
// assert.NotElementsMatchf(t, [1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true
func NotElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...)
}
// NotEmptyf asserts that the specified object is NOT [Empty].
// //
// if assert.NotEmptyf(t, obj, "error message %s", "formatted") { // if assert.NotEmptyf(t, obj, "error message %s", "formatted") {
// assert.Equal(t, "two", obj[1]) // assert.Equal(t, "two", obj[1])
@ -604,7 +639,16 @@ func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg s
return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...) return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// NotErrorIsf asserts that at none of the errors in err's chain matches target. // NotErrorAsf asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func NotErrorAsf(t TestingT, err error, target interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotErrorAs(t, err, target, append([]interface{}{msg}, args...)...)
}
// NotErrorIsf asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is. // This is a wrapper for errors.Is.
func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool { func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
@ -667,12 +711,15 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string,
return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...) return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT // NotSubsetf asserts that the list (array, slice, or map) does NOT contain all
// contain all elements given in the specified subset list(array, slice...) or // elements given in the subset (array, slice, or map).
// map. // Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted") // assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") // assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
// assert.NotSubsetf(t, [1, 3, 4], {1: "one", 2: "two"}, "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, ["z"], "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -756,11 +803,15 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg
return Same(t, expected, actual, append([]interface{}{msg}, args...)...) return Same(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// Subsetf asserts that the specified list(array, slice...) or map contains all // Subsetf asserts that the list (array, slice, or map) contains all elements
// elements given in the specified subset list(array, slice...) or map. // given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted") // assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") // assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
// assert.Subsetf(t, [1, 2, 3], {1: "one", 2: "two"}, "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, ["x"], "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()

View file

@ -92,10 +92,19 @@ func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg st
return ElementsMatchf(a.t, listA, listB, msg, args...) return ElementsMatchf(a.t, listA, listB, msg, args...)
} }
// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either // Empty asserts that the given value is "empty".
// a slice or a channel with len == 0. //
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
// //
// a.Empty(obj) // a.Empty(obj)
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -103,10 +112,19 @@ func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool {
return Empty(a.t, object, msgAndArgs...) return Empty(a.t, object, msgAndArgs...)
} }
// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either // Emptyf asserts that the given value is "empty".
// a slice or a channel with len == 0. //
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
// //
// a.Emptyf(obj, "error message %s", "formatted") // a.Emptyf(obj, "error message %s", "formatted")
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -186,8 +204,8 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
return EqualExportedValuesf(a.t, expected, actual, msg, args...) return EqualExportedValuesf(a.t, expected, actual, msg, args...)
} }
// EqualValues asserts that two objects are equal or convertible to the same types // EqualValues asserts that two objects are equal or convertible to the larger
// and equal. // type and equal.
// //
// a.EqualValues(uint32(123), int32(123)) // a.EqualValues(uint32(123), int32(123))
func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
@ -197,8 +215,8 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
return EqualValues(a.t, expected, actual, msgAndArgs...) return EqualValues(a.t, expected, actual, msgAndArgs...)
} }
// EqualValuesf asserts that two objects are equal or convertible to the same types // EqualValuesf asserts that two objects are equal or convertible to the larger
// and equal. // type and equal.
// //
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
@ -224,10 +242,8 @@ func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string
// Error asserts that a function returned an error (i.e. not `nil`). // Error asserts that a function returned an error (i.e. not `nil`).
// //
// actualObj, err := SomeFunction() // actualObj, err := SomeFunction()
// if a.Error(err) { // a.Error(err)
// assert.Equal(t, expectedError, err)
// }
func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool { func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -297,10 +313,8 @@ func (a *Assertions) ErrorIsf(err error, target error, msg string, args ...inter
// Errorf asserts that a function returned an error (i.e. not `nil`). // Errorf asserts that a function returned an error (i.e. not `nil`).
// //
// actualObj, err := SomeFunction() // actualObj, err := SomeFunction()
// if a.Errorf(err, "error message %s", "formatted") { // a.Errorf(err, "error message %s", "formatted")
// assert.Equal(t, expectedErrorf, err)
// }
func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool { func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -336,7 +350,7 @@ func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, ti
// a.EventuallyWithT(func(c *assert.CollectT) { // a.EventuallyWithT(func(c *assert.CollectT) {
// // add assertions as needed; any assertion failure will fail the current tick // // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true") // assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") // }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -361,7 +375,7 @@ func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor
// a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") { // a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") {
// // add assertions as needed; any assertion failure will fail the current tick // // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true") // assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") // }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func (a *Assertions) EventuallyWithTf(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { func (a *Assertions) EventuallyWithTf(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -868,7 +882,29 @@ func (a *Assertions) IsNonIncreasingf(object interface{}, msg string, args ...in
return IsNonIncreasingf(a.t, object, msg, args...) return IsNonIncreasingf(a.t, object, msg, args...)
} }
// IsNotType asserts that the specified objects are not of the same type.
//
// a.IsNotType(&NotMyStruct{}, &MyStruct{})
func (a *Assertions) IsNotType(theType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsNotType(a.t, theType, object, msgAndArgs...)
}
// IsNotTypef asserts that the specified objects are not of the same type.
//
// a.IsNotTypef(&NotMyStruct{}, &MyStruct{}, "error message %s", "formatted")
func (a *Assertions) IsNotTypef(theType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsNotTypef(a.t, theType, object, msg, args...)
}
// IsType asserts that the specified objects are of the same type. // IsType asserts that the specified objects are of the same type.
//
// a.IsType(&MyStruct{}, &MyStruct{})
func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -877,6 +913,8 @@ func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAnd
} }
// IsTypef asserts that the specified objects are of the same type. // IsTypef asserts that the specified objects are of the same type.
//
// a.IsTypef(&MyStruct{}, &MyStruct{}, "error message %s", "formatted")
func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) bool { func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1128,8 +1166,41 @@ func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg strin
return NotContainsf(a.t, s, contains, msg, args...) return NotContainsf(a.t, s, contains, msg, args...)
} }
// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified
// a slice or a channel with len == 0. // listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// a.NotElementsMatch([1, 1, 2, 3], [1, 1, 2, 3]) -> false
//
// a.NotElementsMatch([1, 1, 2, 3], [1, 2, 3]) -> true
//
// a.NotElementsMatch([1, 2, 3], [1, 2, 4]) -> true
func (a *Assertions) NotElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotElementsMatch(a.t, listA, listB, msgAndArgs...)
}
// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// a.NotElementsMatchf([1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false
//
// a.NotElementsMatchf([1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true
//
// a.NotElementsMatchf([1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true
func (a *Assertions) NotElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotElementsMatchf(a.t, listA, listB, msg, args...)
}
// NotEmpty asserts that the specified object is NOT [Empty].
// //
// if a.NotEmpty(obj) { // if a.NotEmpty(obj) {
// assert.Equal(t, "two", obj[1]) // assert.Equal(t, "two", obj[1])
@ -1141,8 +1212,7 @@ func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) boo
return NotEmpty(a.t, object, msgAndArgs...) return NotEmpty(a.t, object, msgAndArgs...)
} }
// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // NotEmptyf asserts that the specified object is NOT [Empty].
// a slice or a channel with len == 0.
// //
// if a.NotEmptyf(obj, "error message %s", "formatted") { // if a.NotEmptyf(obj, "error message %s", "formatted") {
// assert.Equal(t, "two", obj[1]) // assert.Equal(t, "two", obj[1])
@ -1200,7 +1270,25 @@ func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg str
return NotEqualf(a.t, expected, actual, msg, args...) return NotEqualf(a.t, expected, actual, msg, args...)
} }
// NotErrorIs asserts that at none of the errors in err's chain matches target. // NotErrorAs asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func (a *Assertions) NotErrorAs(err error, target interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotErrorAs(a.t, err, target, msgAndArgs...)
}
// NotErrorAsf asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func (a *Assertions) NotErrorAsf(err error, target interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotErrorAsf(a.t, err, target, msg, args...)
}
// NotErrorIs asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is. // This is a wrapper for errors.Is.
func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) bool { func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
@ -1209,7 +1297,7 @@ func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface
return NotErrorIs(a.t, err, target, msgAndArgs...) return NotErrorIs(a.t, err, target, msgAndArgs...)
} }
// NotErrorIsf asserts that at none of the errors in err's chain matches target. // NotErrorIsf asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is. // This is a wrapper for errors.Is.
func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) bool { func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
@ -1326,12 +1414,15 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
return NotSamef(a.t, expected, actual, msg, args...) return NotSamef(a.t, expected, actual, msg, args...)
} }
// NotSubset asserts that the specified list(array, slice...) or map does NOT // NotSubset asserts that the list (array, slice, or map) does NOT contain all
// contain all elements given in the specified subset list(array, slice...) or // elements given in the subset (array, slice, or map).
// map. // Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// a.NotSubset([1, 3, 4], [1, 2]) // a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3}) // a.NotSubset({"x": 1, "y": 2}, {"z": 3})
// a.NotSubset([1, 3, 4], {1: "one", 2: "two"})
// a.NotSubset({"x": 1, "y": 2}, ["z"])
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1339,12 +1430,15 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
return NotSubset(a.t, list, subset, msgAndArgs...) return NotSubset(a.t, list, subset, msgAndArgs...)
} }
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT // NotSubsetf asserts that the list (array, slice, or map) does NOT contain all
// contain all elements given in the specified subset list(array, slice...) or // elements given in the subset (array, slice, or map).
// map. // Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted") // a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") // a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
// a.NotSubsetf([1, 3, 4], {1: "one", 2: "two"}, "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, ["z"], "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1504,11 +1598,15 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
return Samef(a.t, expected, actual, msg, args...) return Samef(a.t, expected, actual, msg, args...)
} }
// Subset asserts that the specified list(array, slice...) or map contains all // Subset asserts that the list (array, slice, or map) contains all elements
// elements given in the specified subset list(array, slice...) or map. // given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// a.Subset([1, 2, 3], [1, 2]) // a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1}) // a.Subset({"x": 1, "y": 2}, {"x": 1})
// a.Subset([1, 2, 3], {1: "one", 2: "two"})
// a.Subset({"x": 1, "y": 2}, ["x"])
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1516,11 +1614,15 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
return Subset(a.t, list, subset, msgAndArgs...) return Subset(a.t, list, subset, msgAndArgs...)
} }
// Subsetf asserts that the specified list(array, slice...) or map contains all // Subsetf asserts that the list (array, slice, or map) contains all elements
// elements given in the specified subset list(array, slice...) or map. // given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted") // a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") // a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
// a.Subsetf([1, 2, 3], {1: "one", 2: "two"}, "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, ["x"], "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()

View file

@ -6,7 +6,7 @@ import (
) )
// isOrdered checks that collection contains orderable elements. // isOrdered checks that collection contains orderable elements.
func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool { func isOrdered(t TestingT, object interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool {
objKind := reflect.TypeOf(object).Kind() objKind := reflect.TypeOf(object).Kind()
if objKind != reflect.Slice && objKind != reflect.Array { if objKind != reflect.Slice && objKind != reflect.Array {
return false return false
@ -33,7 +33,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
compareResult, isComparable := compare(prevValueInterface, valueInterface, firstValueKind) compareResult, isComparable := compare(prevValueInterface, valueInterface, firstValueKind)
if !isComparable { if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\" and \"%s\"", reflect.TypeOf(value), reflect.TypeOf(prevValue)), msgAndArgs...) return Fail(t, fmt.Sprintf(`Can not compare type "%T" and "%T"`, value, prevValue), msgAndArgs...)
} }
if !containsValue(allowedComparesResults, compareResult) { if !containsValue(allowedComparesResults, compareResult) {
@ -50,7 +50,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
// assert.IsIncreasing(t, []float{1, 2}) // assert.IsIncreasing(t, []float{1, 2})
// assert.IsIncreasing(t, []string{"a", "b"}) // assert.IsIncreasing(t, []string{"a", "b"})
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) return isOrdered(t, object, []compareResult{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
} }
// IsNonIncreasing asserts that the collection is not increasing // IsNonIncreasing asserts that the collection is not increasing
@ -59,7 +59,7 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonIncreasing(t, []float{2, 1}) // assert.IsNonIncreasing(t, []float{2, 1})
// assert.IsNonIncreasing(t, []string{"b", "a"}) // assert.IsNonIncreasing(t, []string{"b", "a"})
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) return isOrdered(t, object, []compareResult{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
} }
// IsDecreasing asserts that the collection is decreasing // IsDecreasing asserts that the collection is decreasing
@ -68,7 +68,7 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{})
// assert.IsDecreasing(t, []float{2, 1}) // assert.IsDecreasing(t, []float{2, 1})
// assert.IsDecreasing(t, []string{"b", "a"}) // assert.IsDecreasing(t, []string{"b", "a"})
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) return isOrdered(t, object, []compareResult{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
} }
// IsNonDecreasing asserts that the collection is not decreasing // IsNonDecreasing asserts that the collection is not decreasing
@ -77,5 +77,5 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonDecreasing(t, []float{1, 2}) // assert.IsNonDecreasing(t, []float{1, 2})
// assert.IsNonDecreasing(t, []string{"a", "b"}) // assert.IsNonDecreasing(t, []string{"a", "b"})
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) return isOrdered(t, object, []compareResult{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
} }

View file

@ -19,7 +19,9 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib" "github.com/pmezard/go-difflib/difflib"
"gopkg.in/yaml.v3"
// Wrapper around gopkg.in/yaml.v3
"github.com/stretchr/testify/assert/yaml"
) )
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl" //go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl"
@ -45,6 +47,10 @@ type BoolAssertionFunc func(TestingT, bool, ...interface{}) bool
// for table driven tests. // for table driven tests.
type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool
// PanicAssertionFunc is a common function prototype when validating a panic value. Can be useful
// for table driven tests.
type PanicAssertionFunc = func(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool
// Comparison is a custom function that returns true on success and false on failure // Comparison is a custom function that returns true on success and false on failure
type Comparison func() (success bool) type Comparison func() (success bool)
@ -204,59 +210,77 @@ the problem actually occurred in calling code.*/
// of each stack frame leading from the current test to the assert call that // of each stack frame leading from the current test to the assert call that
// failed. // failed.
func CallerInfo() []string { func CallerInfo() []string {
var pc uintptr var pc uintptr
var ok bool
var file string var file string
var line int var line int
var name string var name string
const stackFrameBufferSize = 10
pcs := make([]uintptr, stackFrameBufferSize)
callers := []string{} callers := []string{}
for i := 0; ; i++ { offset := 1
pc, file, line, ok = runtime.Caller(i)
if !ok { for {
// The breaks below failed to terminate the loop, and we ran off the n := runtime.Callers(offset, pcs)
// end of the call stack.
if n == 0 {
break break
} }
// This is a huge edge case, but it will panic if this is the case, see #180 frames := runtime.CallersFrames(pcs[:n])
if file == "<autogenerated>" {
break
}
f := runtime.FuncForPC(pc) for {
if f == nil { frame, more := frames.Next()
break pc = frame.PC
} file = frame.File
name = f.Name() line = frame.Line
// testing.tRunner is the standard library function that calls // This is a huge edge case, but it will panic if this is the case, see #180
// tests. Subtests are called directly by tRunner, without going through if file == "<autogenerated>" {
// the Test/Benchmark/Example function that contains the t.Run calls, so break
// with subtests we should break when we hit tRunner, without adding it }
// to the list of callers.
if name == "testing.tRunner" {
break
}
parts := strings.Split(file, "/") f := runtime.FuncForPC(pc)
if len(parts) > 1 { if f == nil {
filename := parts[len(parts)-1] break
dir := parts[len(parts)-2] }
if (dir != "assert" && dir != "mock" && dir != "require") || filename == "mock_test.go" { name = f.Name()
callers = append(callers, fmt.Sprintf("%s:%d", file, line))
// testing.tRunner is the standard library function that calls
// tests. Subtests are called directly by tRunner, without going through
// the Test/Benchmark/Example function that contains the t.Run calls, so
// with subtests we should break when we hit tRunner, without adding it
// to the list of callers.
if name == "testing.tRunner" {
break
}
parts := strings.Split(file, "/")
if len(parts) > 1 {
filename := parts[len(parts)-1]
dir := parts[len(parts)-2]
if (dir != "assert" && dir != "mock" && dir != "require") || filename == "mock_test.go" {
callers = append(callers, fmt.Sprintf("%s:%d", file, line))
}
}
// Drop the package
dotPos := strings.LastIndexByte(name, '.')
name = name[dotPos+1:]
if isTest(name, "Test") ||
isTest(name, "Benchmark") ||
isTest(name, "Example") {
break
}
if !more {
break
} }
} }
// Drop the package // Next batch
segments := strings.Split(name, ".") offset += cap(pcs)
name = segments[len(segments)-1]
if isTest(name, "Test") ||
isTest(name, "Benchmark") ||
isTest(name, "Example") {
break
}
} }
return callers return callers
@ -431,17 +455,34 @@ func NotImplements(t TestingT, interfaceObject interface{}, object interface{},
return true return true
} }
func isType(expectedType, object interface{}) bool {
return ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType))
}
// IsType asserts that the specified objects are of the same type. // IsType asserts that the specified objects are of the same type.
func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { //
// assert.IsType(t, &MyStruct{}, &MyStruct{})
func IsType(t TestingT, expectedType, object interface{}, msgAndArgs ...interface{}) bool {
if isType(expectedType, object) {
return true
}
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
return Fail(t, fmt.Sprintf("Object expected to be of type %T, but was %T", expectedType, object), msgAndArgs...)
}
if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) { // IsNotType asserts that the specified objects are not of the same type.
return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...) //
// assert.IsNotType(t, &NotMyStruct{}, &MyStruct{})
func IsNotType(t TestingT, theType, object interface{}, msgAndArgs ...interface{}) bool {
if !isType(theType, object) {
return true
} }
if h, ok := t.(tHelper); ok {
return true h.Helper()
}
return Fail(t, fmt.Sprintf("Object type expected to be different than %T", theType), msgAndArgs...)
} }
// Equal asserts that two objects are equal. // Equal asserts that two objects are equal.
@ -469,7 +510,6 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{})
} }
return true return true
} }
// validateEqualArgs checks whether provided arguments can be safely used in the // validateEqualArgs checks whether provided arguments can be safely used in the
@ -496,10 +536,17 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b
h.Helper() h.Helper()
} }
if !samePointers(expected, actual) { same, ok := samePointers(expected, actual)
if !ok {
return Fail(t, "Both arguments must be pointers", msgAndArgs...)
}
if !same {
// both are pointers but not the same type & pointing to the same address
return Fail(t, fmt.Sprintf("Not same: \n"+ return Fail(t, fmt.Sprintf("Not same: \n"+
"expected: %p %#v\n"+ "expected: %p %#[1]v\n"+
"actual : %p %#v", expected, expected, actual, actual), msgAndArgs...) "actual : %p %#[2]v",
expected, actual), msgAndArgs...)
} }
return true return true
@ -516,29 +563,37 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
h.Helper() h.Helper()
} }
if samePointers(expected, actual) { same, ok := samePointers(expected, actual)
if !ok {
// fails when the arguments are not pointers
return !(Fail(t, "Both arguments must be pointers", msgAndArgs...))
}
if same {
return Fail(t, fmt.Sprintf( return Fail(t, fmt.Sprintf(
"Expected and actual point to the same object: %p %#v", "Expected and actual point to the same object: %p %#[1]v",
expected, expected), msgAndArgs...) expected), msgAndArgs...)
} }
return true return true
} }
// samePointers compares two generic interface objects and returns whether // samePointers checks if two generic interface objects are pointers of the same
// they point to the same object // type pointing to the same object. It returns two values: same indicating if
func samePointers(first, second interface{}) bool { // they are the same type and point to the same object, and ok indicating that
// both inputs are pointers.
func samePointers(first, second interface{}) (same bool, ok bool) {
firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second) firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second)
if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr { if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr {
return false return false, false // not both are pointers
} }
firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second) firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second)
if firstType != secondType { if firstType != secondType {
return false return false, true // both are pointers, but of different types
} }
// compare pointer addresses // compare pointer addresses
return first == second return first == second, true
} }
// formatUnequalValues takes two values of arbitrary types and returns string // formatUnequalValues takes two values of arbitrary types and returns string
@ -572,8 +627,8 @@ func truncatingFormat(data interface{}) string {
return value return value
} }
// EqualValues asserts that two objects are equal or convertible to the same types // EqualValues asserts that two objects are equal or convertible to the larger
// and equal. // type and equal.
// //
// assert.EqualValues(t, uint32(123), int32(123)) // assert.EqualValues(t, uint32(123), int32(123))
func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
@ -590,7 +645,6 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa
} }
return true return true
} }
// EqualExportedValues asserts that the types of two objects are equal and their public // EqualExportedValues asserts that the types of two objects are equal and their public
@ -615,21 +669,6 @@ func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs ..
return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...)
} }
if aType.Kind() == reflect.Ptr {
aType = aType.Elem()
}
if bType.Kind() == reflect.Ptr {
bType = bType.Elem()
}
if aType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...)
}
if bType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...)
}
expected = copyExportedFields(expected) expected = copyExportedFields(expected)
actual = copyExportedFields(actual) actual = copyExportedFields(actual)
@ -660,7 +699,6 @@ func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
} }
return Equal(t, expected, actual, msgAndArgs...) return Equal(t, expected, actual, msgAndArgs...)
} }
// NotNil asserts that the specified object is not nil. // NotNil asserts that the specified object is not nil.
@ -710,37 +748,45 @@ func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
// isEmpty gets whether the specified object is considered empty or not. // isEmpty gets whether the specified object is considered empty or not.
func isEmpty(object interface{}) bool { func isEmpty(object interface{}) bool {
// get nil case out of the way // get nil case out of the way
if object == nil { if object == nil {
return true return true
} }
objValue := reflect.ValueOf(object) return isEmptyValue(reflect.ValueOf(object))
switch objValue.Kind() {
// collection types are empty when they have no element
case reflect.Chan, reflect.Map, reflect.Slice:
return objValue.Len() == 0
// pointers are empty if nil or if the value they point to is empty
case reflect.Ptr:
if objValue.IsNil() {
return true
}
deref := objValue.Elem().Interface()
return isEmpty(deref)
// for all other types, compare against the zero value
// array types are empty when they match their zero-initialized state
default:
zero := reflect.Zero(objValue.Type())
return reflect.DeepEqual(object, zero.Interface())
}
} }
// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either // isEmptyValue gets whether the specified reflect.Value is considered empty or not.
// a slice or a channel with len == 0. func isEmptyValue(objValue reflect.Value) bool {
if objValue.IsZero() {
return true
}
// Special cases of non-zero values that we consider empty
switch objValue.Kind() {
// collection types are empty when they have no element
// Note: array types are empty when they match their zero-initialized state.
case reflect.Chan, reflect.Map, reflect.Slice:
return objValue.Len() == 0
// non-nil pointers are empty if the value they point to is empty
case reflect.Ptr:
return isEmptyValue(objValue.Elem())
}
return false
}
// Empty asserts that the given value is "empty".
//
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
// //
// assert.Empty(t, obj) // assert.Empty(t, obj)
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
pass := isEmpty(object) pass := isEmpty(object)
if !pass { if !pass {
@ -751,11 +797,9 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
} }
return pass return pass
} }
// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // NotEmpty asserts that the specified object is NOT [Empty].
// a slice or a channel with len == 0.
// //
// if assert.NotEmpty(t, obj) { // if assert.NotEmpty(t, obj) {
// assert.Equal(t, "two", obj[1]) // assert.Equal(t, "two", obj[1])
@ -770,7 +814,6 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
} }
return pass return pass
} }
// getLen tries to get the length of an object. // getLen tries to get the length of an object.
@ -814,7 +857,6 @@ func True(t TestingT, value bool, msgAndArgs ...interface{}) bool {
} }
return true return true
} }
// False asserts that the specified value is false. // False asserts that the specified value is false.
@ -829,7 +871,6 @@ func False(t TestingT, value bool, msgAndArgs ...interface{}) bool {
} }
return true return true
} }
// NotEqual asserts that the specified values are NOT equal. // NotEqual asserts that the specified values are NOT equal.
@ -852,7 +893,6 @@ func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{
} }
return true return true
} }
// NotEqualValues asserts that two objects are not equal even when converted to the same type // NotEqualValues asserts that two objects are not equal even when converted to the same type
@ -875,7 +915,6 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte
// return (true, false) if element was not found. // return (true, false) if element was not found.
// return (true, true) if element was found. // return (true, true) if element was found.
func containsElement(list interface{}, element interface{}) (ok, found bool) { func containsElement(list interface{}, element interface{}) (ok, found bool) {
listValue := reflect.ValueOf(list) listValue := reflect.ValueOf(list)
listType := reflect.TypeOf(list) listType := reflect.TypeOf(list)
if listType == nil { if listType == nil {
@ -910,7 +949,6 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) {
} }
} }
return true, false return true, false
} }
// Contains asserts that the specified string, list(array, slice...) or map contains the // Contains asserts that the specified string, list(array, slice...) or map contains the
@ -933,7 +971,6 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
} }
return true return true
} }
// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the // NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the
@ -956,14 +993,17 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
} }
return true return true
} }
// Subset asserts that the specified list(array, slice...) or map contains all // Subset asserts that the list (array, slice, or map) contains all elements
// elements given in the specified subset list(array, slice...) or map. // given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// assert.Subset(t, [1, 2, 3], [1, 2]) // assert.Subset(t, [1, 2, 3], [1, 2])
// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1}) // assert.Subset(t, {"x": 1, "y": 2}, {"x": 1})
// assert.Subset(t, [1, 2, 3], {1: "one", 2: "two"})
// assert.Subset(t, {"x": 1, "y": 2}, ["x"])
func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -978,7 +1018,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
} }
subsetKind := reflect.TypeOf(subset).Kind() subsetKind := reflect.TypeOf(subset).Kind()
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map { if subsetKind != reflect.Array && subsetKind != reflect.Slice && subsetKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
} }
@ -1002,6 +1042,13 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
} }
subsetList := reflect.ValueOf(subset) subsetList := reflect.ValueOf(subset)
if subsetKind == reflect.Map {
keys := make([]interface{}, subsetList.Len())
for idx, key := range subsetList.MapKeys() {
keys[idx] = key.Interface()
}
subsetList = reflect.ValueOf(keys)
}
for i := 0; i < subsetList.Len(); i++ { for i := 0; i < subsetList.Len(); i++ {
element := subsetList.Index(i).Interface() element := subsetList.Index(i).Interface()
ok, found := containsElement(list, element) ok, found := containsElement(list, element)
@ -1016,12 +1063,15 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true return true
} }
// NotSubset asserts that the specified list(array, slice...) or map does NOT // NotSubset asserts that the list (array, slice, or map) does NOT contain all
// contain all elements given in the specified subset list(array, slice...) or // elements given in the subset (array, slice, or map).
// map. // Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// assert.NotSubset(t, [1, 3, 4], [1, 2]) // assert.NotSubset(t, [1, 3, 4], [1, 2])
// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3}) // assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3})
// assert.NotSubset(t, [1, 3, 4], {1: "one", 2: "two"})
// assert.NotSubset(t, {"x": 1, "y": 2}, ["z"])
func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -1036,7 +1086,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
} }
subsetKind := reflect.TypeOf(subset).Kind() subsetKind := reflect.TypeOf(subset).Kind()
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map { if subsetKind != reflect.Array && subsetKind != reflect.Slice && subsetKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
} }
@ -1060,11 +1110,18 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
} }
subsetList := reflect.ValueOf(subset) subsetList := reflect.ValueOf(subset)
if subsetKind == reflect.Map {
keys := make([]interface{}, subsetList.Len())
for idx, key := range subsetList.MapKeys() {
keys[idx] = key.Interface()
}
subsetList = reflect.ValueOf(keys)
}
for i := 0; i < subsetList.Len(); i++ { for i := 0; i < subsetList.Len(); i++ {
element := subsetList.Index(i).Interface() element := subsetList.Index(i).Interface()
ok, found := containsElement(list, element) ok, found := containsElement(list, element)
if !ok { if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...) return Fail(t, fmt.Sprintf("%q could not be applied builtin len()", list), msgAndArgs...)
} }
if !found { if !found {
return true return true
@ -1170,6 +1227,39 @@ func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) stri
return msg.String() return msg.String()
} }
// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// assert.NotElementsMatch(t, [1, 1, 2, 3], [1, 1, 2, 3]) -> false
//
// assert.NotElementsMatch(t, [1, 1, 2, 3], [1, 2, 3]) -> true
//
// assert.NotElementsMatch(t, [1, 2, 3], [1, 2, 4]) -> true
func NotElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if isEmpty(listA) && isEmpty(listB) {
return Fail(t, "listA and listB contain the same elements", msgAndArgs)
}
if !isList(t, listA, msgAndArgs...) {
return Fail(t, "listA is not a list type", msgAndArgs...)
}
if !isList(t, listB, msgAndArgs...) {
return Fail(t, "listB is not a list type", msgAndArgs...)
}
extraA, extraB := diffLists(listA, listB)
if len(extraA) == 0 && len(extraB) == 0 {
return Fail(t, "listA and listB contain the same elements", msgAndArgs)
}
return true
}
// Condition uses a Comparison to assert a complex condition. // Condition uses a Comparison to assert a complex condition.
func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool { func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
@ -1488,6 +1578,9 @@ func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAnd
if err != nil { if err != nil {
return Fail(t, err.Error(), msgAndArgs...) return Fail(t, err.Error(), msgAndArgs...)
} }
if math.IsNaN(actualEpsilon) {
return Fail(t, "relative error is NaN", msgAndArgs...)
}
if actualEpsilon > epsilon { if actualEpsilon > epsilon {
return Fail(t, fmt.Sprintf("Relative error is too high: %#v (expected)\n"+ return Fail(t, fmt.Sprintf("Relative error is too high: %#v (expected)\n"+
" < %#v (actual)", epsilon, actualEpsilon), msgAndArgs...) " < %#v (actual)", epsilon, actualEpsilon), msgAndArgs...)
@ -1550,10 +1643,8 @@ func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool {
// Error asserts that a function returned an error (i.e. not `nil`). // Error asserts that a function returned an error (i.e. not `nil`).
// //
// actualObj, err := SomeFunction() // actualObj, err := SomeFunction()
// if assert.Error(t, err) { // assert.Error(t, err)
// assert.Equal(t, expectedError, err)
// }
func Error(t TestingT, err error, msgAndArgs ...interface{}) bool { func Error(t TestingT, err error, msgAndArgs ...interface{}) bool {
if err == nil { if err == nil {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
@ -1611,7 +1702,6 @@ func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...in
// matchRegexp return true if a specified regexp matches a string. // matchRegexp return true if a specified regexp matches a string.
func matchRegexp(rx interface{}, str interface{}) bool { func matchRegexp(rx interface{}, str interface{}) bool {
var r *regexp.Regexp var r *regexp.Regexp
if rr, ok := rx.(*regexp.Regexp); ok { if rr, ok := rx.(*regexp.Regexp); ok {
r = rr r = rr
@ -1619,8 +1709,14 @@ func matchRegexp(rx interface{}, str interface{}) bool {
r = regexp.MustCompile(fmt.Sprint(rx)) r = regexp.MustCompile(fmt.Sprint(rx))
} }
return (r.FindStringIndex(fmt.Sprint(str)) != nil) switch v := str.(type) {
case []byte:
return r.Match(v)
case string:
return r.MatchString(v)
default:
return r.MatchString(fmt.Sprint(v))
}
} }
// Regexp asserts that a specified regexp matches a string. // Regexp asserts that a specified regexp matches a string.
@ -1656,7 +1752,6 @@ func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interf
} }
return !match return !match
} }
// Zero asserts that i is the zero value for its type. // Zero asserts that i is the zero value for its type.
@ -1767,6 +1862,11 @@ func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{
return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid json.\nJSON parsing error: '%s'", expected, err.Error()), msgAndArgs...) return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid json.\nJSON parsing error: '%s'", expected, err.Error()), msgAndArgs...)
} }
// Shortcut if same bytes
if actual == expected {
return true
}
if err := json.Unmarshal([]byte(actual), &actualJSONAsInterface); err != nil { if err := json.Unmarshal([]byte(actual), &actualJSONAsInterface); err != nil {
return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid json.\nJSON parsing error: '%s'", actual, err.Error()), msgAndArgs...) return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid json.\nJSON parsing error: '%s'", actual, err.Error()), msgAndArgs...)
} }
@ -1785,6 +1885,11 @@ func YAMLEq(t TestingT, expected string, actual string, msgAndArgs ...interface{
return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid yaml.\nYAML parsing error: '%s'", expected, err.Error()), msgAndArgs...) return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid yaml.\nYAML parsing error: '%s'", expected, err.Error()), msgAndArgs...)
} }
// Shortcut if same bytes
if actual == expected {
return true
}
if err := yaml.Unmarshal([]byte(actual), &actualYAMLAsInterface); err != nil { if err := yaml.Unmarshal([]byte(actual), &actualYAMLAsInterface); err != nil {
return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid yaml.\nYAML error: '%s'", actual, err.Error()), msgAndArgs...) return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid yaml.\nYAML error: '%s'", actual, err.Error()), msgAndArgs...)
} }
@ -1872,7 +1977,7 @@ var spewConfigStringerEnabled = spew.ConfigState{
MaxDepth: 10, MaxDepth: 10,
} }
type tHelper interface { type tHelper = interface {
Helper() Helper()
} }
@ -1886,6 +1991,7 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
} }
ch := make(chan bool, 1) ch := make(chan bool, 1)
checkCond := func() { ch <- condition() }
timer := time.NewTimer(waitFor) timer := time.NewTimer(waitFor)
defer timer.Stop() defer timer.Stop()
@ -1893,35 +1999,47 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
ticker := time.NewTicker(tick) ticker := time.NewTicker(tick)
defer ticker.Stop() defer ticker.Stop()
for tick := ticker.C; ; { var tickC <-chan time.Time
// Check the condition once first on the initial call.
go checkCond()
for {
select { select {
case <-timer.C: case <-timer.C:
return Fail(t, "Condition never satisfied", msgAndArgs...) return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick: case <-tickC:
tick = nil tickC = nil
go func() { ch <- condition() }() go checkCond()
case v := <-ch: case v := <-ch:
if v { if v {
return true return true
} }
tick = ticker.C tickC = ticker.C
} }
} }
} }
// CollectT implements the TestingT interface and collects all errors. // CollectT implements the TestingT interface and collects all errors.
type CollectT struct { type CollectT struct {
// A slice of errors. Non-nil slice denotes a failure.
// If it's non-nil but len(c.errors) == 0, this is also a failure
// obtained by direct c.FailNow() call.
errors []error errors []error
} }
// Helper is like [testing.T.Helper] but does nothing.
func (CollectT) Helper() {}
// Errorf collects the error. // Errorf collects the error.
func (c *CollectT) Errorf(format string, args ...interface{}) { func (c *CollectT) Errorf(format string, args ...interface{}) {
c.errors = append(c.errors, fmt.Errorf(format, args...)) c.errors = append(c.errors, fmt.Errorf(format, args...))
} }
// FailNow panics. // FailNow stops execution by calling runtime.Goexit.
func (*CollectT) FailNow() { func (c *CollectT) FailNow() {
panic("Assertion failed") c.fail()
runtime.Goexit()
} }
// Deprecated: That was a method for internal usage that should not have been published. Now just panics. // Deprecated: That was a method for internal usage that should not have been published. Now just panics.
@ -1934,6 +2052,16 @@ func (*CollectT) Copy(TestingT) {
panic("Copy() is deprecated") panic("Copy() is deprecated")
} }
func (c *CollectT) fail() {
if !c.failed() {
c.errors = []error{} // Make it non-nil to mark a failure.
}
}
func (c *CollectT) failed() bool {
return c.errors != nil
}
// EventuallyWithT asserts that given condition will be met in waitFor time, // EventuallyWithT asserts that given condition will be met in waitFor time,
// periodically checking target function each tick. In contrast to Eventually, // periodically checking target function each tick. In contrast to Eventually,
// it supplies a CollectT to the condition function, so that the condition // it supplies a CollectT to the condition function, so that the condition
@ -1951,14 +2079,22 @@ func (*CollectT) Copy(TestingT) {
// assert.EventuallyWithT(t, func(c *assert.CollectT) { // assert.EventuallyWithT(t, func(c *assert.CollectT) {
// // add assertions as needed; any assertion failure will fail the current tick // // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true") // assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") // }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
var lastFinishedTickErrs []error var lastFinishedTickErrs []error
ch := make(chan []error, 1) ch := make(chan *CollectT, 1)
checkCond := func() {
collect := new(CollectT)
defer func() {
ch <- collect
}()
condition(collect)
}
timer := time.NewTimer(waitFor) timer := time.NewTimer(waitFor)
defer timer.Stop() defer timer.Stop()
@ -1966,29 +2102,28 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
ticker := time.NewTicker(tick) ticker := time.NewTicker(tick)
defer ticker.Stop() defer ticker.Stop()
for tick := ticker.C; ; { var tickC <-chan time.Time
// Check the condition once first on the initial call.
go checkCond()
for {
select { select {
case <-timer.C: case <-timer.C:
for _, err := range lastFinishedTickErrs { for _, err := range lastFinishedTickErrs {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
return Fail(t, "Condition never satisfied", msgAndArgs...) return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick: case <-tickC:
tick = nil tickC = nil
go func() { go checkCond()
collect := new(CollectT) case collect := <-ch:
defer func() { if !collect.failed() {
ch <- collect.errors
}()
condition(collect)
}()
case errs := <-ch:
if len(errs) == 0 {
return true return true
} }
// Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached. // Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached.
lastFinishedTickErrs = errs lastFinishedTickErrs = collect.errors
tick = ticker.C tickC = ticker.C
} }
} }
} }
@ -2003,6 +2138,7 @@ func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.D
} }
ch := make(chan bool, 1) ch := make(chan bool, 1)
checkCond := func() { ch <- condition() }
timer := time.NewTimer(waitFor) timer := time.NewTimer(waitFor)
defer timer.Stop() defer timer.Stop()
@ -2010,18 +2146,23 @@ func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.D
ticker := time.NewTicker(tick) ticker := time.NewTicker(tick)
defer ticker.Stop() defer ticker.Stop()
for tick := ticker.C; ; { var tickC <-chan time.Time
// Check the condition once first on the initial call.
go checkCond()
for {
select { select {
case <-timer.C: case <-timer.C:
return true return true
case <-tick: case <-tickC:
tick = nil tickC = nil
go func() { ch <- condition() }() go checkCond()
case v := <-ch: case v := <-ch:
if v { if v {
return Fail(t, "Condition satisfied", msgAndArgs...) return Fail(t, "Condition satisfied", msgAndArgs...)
} }
tick = ticker.C tickC = ticker.C
} }
} }
} }
@ -2039,9 +2180,12 @@ func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
var expectedText string var expectedText string
if target != nil { if target != nil {
expectedText = target.Error() expectedText = target.Error()
if err == nil {
return Fail(t, fmt.Sprintf("Expected error with %q in chain but got nil.", expectedText), msgAndArgs...)
}
} }
chain := buildErrorChainString(err) chain := buildErrorChainString(err, false)
return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+ return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+
"expected: %q\n"+ "expected: %q\n"+
@ -2049,7 +2193,7 @@ func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
), msgAndArgs...) ), msgAndArgs...)
} }
// NotErrorIs asserts that at none of the errors in err's chain matches target. // NotErrorIs asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is. // This is a wrapper for errors.Is.
func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool { func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
@ -2064,7 +2208,7 @@ func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
expectedText = target.Error() expectedText = target.Error()
} }
chain := buildErrorChainString(err) chain := buildErrorChainString(err, false)
return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+ return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+
"found: %q\n"+ "found: %q\n"+
@ -2082,24 +2226,70 @@ func ErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{
return true return true
} }
chain := buildErrorChainString(err) expectedType := reflect.TypeOf(target).Elem().String()
if err == nil {
return Fail(t, fmt.Sprintf("An error is expected but got nil.\n"+
"expected: %s", expectedType), msgAndArgs...)
}
chain := buildErrorChainString(err, true)
return Fail(t, fmt.Sprintf("Should be in error chain:\n"+ return Fail(t, fmt.Sprintf("Should be in error chain:\n"+
"expected: %q\n"+ "expected: %s\n"+
"in chain: %s", target, chain, "in chain: %s", expectedType, chain,
), msgAndArgs...) ), msgAndArgs...)
} }
func buildErrorChainString(err error) string { // NotErrorAs asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func NotErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !errors.As(err, target) {
return true
}
chain := buildErrorChainString(err, true)
return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+
"found: %s\n"+
"in chain: %s", reflect.TypeOf(target).Elem().String(), chain,
), msgAndArgs...)
}
func unwrapAll(err error) (errs []error) {
errs = append(errs, err)
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return
}
errs = append(errs, unwrapAll(err)...)
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
errs = append(errs, unwrapAll(err)...)
}
}
return
}
func buildErrorChainString(err error, withType bool) string {
if err == nil { if err == nil {
return "" return ""
} }
e := errors.Unwrap(err) var chain string
chain := fmt.Sprintf("%q", err.Error()) errs := unwrapAll(err)
for e != nil { for i := range errs {
chain += fmt.Sprintf("\n\t%q", e.Error()) if i != 0 {
e = errors.Unwrap(e) chain += "\n\t"
}
chain += fmt.Sprintf("%q", errs[i].Error())
if withType {
chain += fmt.Sprintf(" (%T)", errs[i])
}
} }
return chain return chain
} }

View file

@ -1,5 +1,9 @@
// Package assert provides a set of comprehensive testing tools for use with the normal Go testing system. // Package assert provides a set of comprehensive testing tools for use with the normal Go testing system.
// //
// # Note
//
// All functions in this package return a bool value indicating whether the assertion has passed.
//
// # Example Usage // # Example Usage
// //
// The following is a complete example using assert in a standard test function: // The following is a complete example using assert in a standard test function:

View file

@ -138,7 +138,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string,
contains := strings.Contains(body, fmt.Sprint(str)) contains := strings.Contains(body, fmt.Sprint(str))
if !contains { if !contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...) Fail(t, fmt.Sprintf("Expected response body for %q to contain %q but found %q", url+"?"+values.Encode(), str, body), msgAndArgs...)
} }
return contains return contains
@ -158,7 +158,7 @@ func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url strin
contains := strings.Contains(body, fmt.Sprint(str)) contains := strings.Contains(body, fmt.Sprint(str))
if contains { if contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...) Fail(t, fmt.Sprintf("Expected response body for %q to NOT contain %q but found %q", url+"?"+values.Encode(), str, body), msgAndArgs...)
} }
return !contains return !contains

View file

@ -0,0 +1,24 @@
//go:build testify_yaml_custom && !testify_yaml_fail && !testify_yaml_default
// Package yaml is an implementation of YAML functions that calls a pluggable implementation.
//
// This implementation is selected with the testify_yaml_custom build tag.
//
// go test -tags testify_yaml_custom
//
// This implementation can be used at build time to replace the default implementation
// to avoid linking with [gopkg.in/yaml.v3].
//
// In your test package:
//
// import assertYaml "github.com/stretchr/testify/assert/yaml"
//
// func init() {
// assertYaml.Unmarshal = func (in []byte, out interface{}) error {
// // ...
// return nil
// }
// }
package yaml
var Unmarshal func(in []byte, out interface{}) error

View file

@ -0,0 +1,36 @@
//go:build !testify_yaml_fail && !testify_yaml_custom
// Package yaml is just an indirection to handle YAML deserialization.
//
// This package is just an indirection that allows the builder to override the
// indirection with an alternative implementation of this package that uses
// another implementation of YAML deserialization. This allows to not either not
// use YAML deserialization at all, or to use another implementation than
// [gopkg.in/yaml.v3] (for example for license compatibility reasons, see [PR #1120]).
//
// Alternative implementations are selected using build tags:
//
// - testify_yaml_fail: [Unmarshal] always fails with an error
// - testify_yaml_custom: [Unmarshal] is a variable. Caller must initialize it
// before calling any of [github.com/stretchr/testify/assert.YAMLEq] or
// [github.com/stretchr/testify/assert.YAMLEqf].
//
// Usage:
//
// go test -tags testify_yaml_fail
//
// You can check with "go list" which implementation is linked:
//
// go list -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml
// go list -tags testify_yaml_fail -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml
// go list -tags testify_yaml_custom -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml
//
// [PR #1120]: https://github.com/stretchr/testify/pull/1120
package yaml
import goyaml "gopkg.in/yaml.v3"
// Unmarshal is just a wrapper of [gopkg.in/yaml.v3.Unmarshal].
func Unmarshal(in []byte, out interface{}) error {
return goyaml.Unmarshal(in, out)
}

View file

@ -0,0 +1,17 @@
//go:build testify_yaml_fail && !testify_yaml_custom && !testify_yaml_default
// Package yaml is an implementation of YAML functions that always fail.
//
// This implementation can be used at build time to replace the default implementation
// to avoid linking with [gopkg.in/yaml.v3]:
//
// go test -tags testify_yaml_fail
package yaml
import "errors"
var errNotImplemented = errors.New("YAML functions are not available (see https://pkg.go.dev/github.com/stretchr/testify/assert/yaml)")
func Unmarshal([]byte, interface{}) error {
return errNotImplemented
}

View file

@ -23,6 +23,8 @@
// //
// The `require` package have same global functions as in the `assert` package, // The `require` package have same global functions as in the `assert` package,
// but instead of returning a boolean result they call `t.FailNow()`. // but instead of returning a boolean result they call `t.FailNow()`.
// A consequence of this is that it must be called from the goroutine running
// the test function, not from other goroutines created during the test.
// //
// Every assertion function also takes an optional string message as the final argument, // Every assertion function also takes an optional string message as the final argument,
// allowing custom error messages to be appended to the message the assertion method outputs. // allowing custom error messages to be appended to the message the assertion method outputs.

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
{{.Comment}} {{ replace .Comment "assert." "require."}}
func {{.DocInfo.Name}}(t TestingT, {{.Params}}) { func {{.DocInfo.Name}}(t TestingT, {{.Params}}) {
if h, ok := t.(tHelper); ok { h.Helper() } if h, ok := t.(tHelper); ok { h.Helper() }
if assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { return } if assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { return }

View file

@ -93,10 +93,19 @@ func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg st
ElementsMatchf(a.t, listA, listB, msg, args...) ElementsMatchf(a.t, listA, listB, msg, args...)
} }
// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either // Empty asserts that the given value is "empty".
// a slice or a channel with len == 0. //
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
// //
// a.Empty(obj) // a.Empty(obj)
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) { func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -104,10 +113,19 @@ func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) {
Empty(a.t, object, msgAndArgs...) Empty(a.t, object, msgAndArgs...)
} }
// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either // Emptyf asserts that the given value is "empty".
// a slice or a channel with len == 0. //
// [Zero values] are "empty".
//
// Arrays are "empty" if every element is the zero value of the type (stricter than "empty").
//
// Slices, maps and channels with zero length are "empty".
//
// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty".
// //
// a.Emptyf(obj, "error message %s", "formatted") // a.Emptyf(obj, "error message %s", "formatted")
//
// [Zero values]: https://go.dev/ref/spec#The_zero_value
func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) { func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -187,8 +205,8 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
EqualExportedValuesf(a.t, expected, actual, msg, args...) EqualExportedValuesf(a.t, expected, actual, msg, args...)
} }
// EqualValues asserts that two objects are equal or convertible to the same types // EqualValues asserts that two objects are equal or convertible to the larger
// and equal. // type and equal.
// //
// a.EqualValues(uint32(123), int32(123)) // a.EqualValues(uint32(123), int32(123))
func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) { func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
@ -198,8 +216,8 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
EqualValues(a.t, expected, actual, msgAndArgs...) EqualValues(a.t, expected, actual, msgAndArgs...)
} }
// EqualValuesf asserts that two objects are equal or convertible to the same types // EqualValuesf asserts that two objects are equal or convertible to the larger
// and equal. // type and equal.
// //
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) { func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) {
@ -225,10 +243,8 @@ func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string
// Error asserts that a function returned an error (i.e. not `nil`). // Error asserts that a function returned an error (i.e. not `nil`).
// //
// actualObj, err := SomeFunction() // actualObj, err := SomeFunction()
// if a.Error(err) { // a.Error(err)
// assert.Equal(t, expectedError, err)
// }
func (a *Assertions) Error(err error, msgAndArgs ...interface{}) { func (a *Assertions) Error(err error, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -298,10 +314,8 @@ func (a *Assertions) ErrorIsf(err error, target error, msg string, args ...inter
// Errorf asserts that a function returned an error (i.e. not `nil`). // Errorf asserts that a function returned an error (i.e. not `nil`).
// //
// actualObj, err := SomeFunction() // actualObj, err := SomeFunction()
// if a.Errorf(err, "error message %s", "formatted") { // a.Errorf(err, "error message %s", "formatted")
// assert.Equal(t, expectedErrorf, err)
// }
func (a *Assertions) Errorf(err error, msg string, args ...interface{}) { func (a *Assertions) Errorf(err error, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -337,7 +351,7 @@ func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, ti
// a.EventuallyWithT(func(c *assert.CollectT) { // a.EventuallyWithT(func(c *assert.CollectT) {
// // add assertions as needed; any assertion failure will fail the current tick // // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true") // assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") // }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func (a *Assertions) EventuallyWithT(condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { func (a *Assertions) EventuallyWithT(condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -362,7 +376,7 @@ func (a *Assertions) EventuallyWithT(condition func(collect *assert.CollectT), w
// a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") { // a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") {
// // add assertions as needed; any assertion failure will fail the current tick // // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true") // assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") // }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
func (a *Assertions) EventuallyWithTf(condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) { func (a *Assertions) EventuallyWithTf(condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -869,7 +883,29 @@ func (a *Assertions) IsNonIncreasingf(object interface{}, msg string, args ...in
IsNonIncreasingf(a.t, object, msg, args...) IsNonIncreasingf(a.t, object, msg, args...)
} }
// IsNotType asserts that the specified objects are not of the same type.
//
// a.IsNotType(&NotMyStruct{}, &MyStruct{})
func (a *Assertions) IsNotType(theType interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
IsNotType(a.t, theType, object, msgAndArgs...)
}
// IsNotTypef asserts that the specified objects are not of the same type.
//
// a.IsNotTypef(&NotMyStruct{}, &MyStruct{}, "error message %s", "formatted")
func (a *Assertions) IsNotTypef(theType interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
IsNotTypef(a.t, theType, object, msg, args...)
}
// IsType asserts that the specified objects are of the same type. // IsType asserts that the specified objects are of the same type.
//
// a.IsType(&MyStruct{}, &MyStruct{})
func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) { func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -878,6 +914,8 @@ func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAnd
} }
// IsTypef asserts that the specified objects are of the same type. // IsTypef asserts that the specified objects are of the same type.
//
// a.IsTypef(&MyStruct{}, &MyStruct{}, "error message %s", "formatted")
func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) { func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1129,8 +1167,41 @@ func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg strin
NotContainsf(a.t, s, contains, msg, args...) NotContainsf(a.t, s, contains, msg, args...)
} }
// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified
// a slice or a channel with len == 0. // listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// a.NotElementsMatch([1, 1, 2, 3], [1, 1, 2, 3]) -> false
//
// a.NotElementsMatch([1, 1, 2, 3], [1, 2, 3]) -> true
//
// a.NotElementsMatch([1, 2, 3], [1, 2, 4]) -> true
func (a *Assertions) NotElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotElementsMatch(a.t, listA, listB, msgAndArgs...)
}
// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
// a.NotElementsMatchf([1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false
//
// a.NotElementsMatchf([1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true
//
// a.NotElementsMatchf([1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true
func (a *Assertions) NotElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotElementsMatchf(a.t, listA, listB, msg, args...)
}
// NotEmpty asserts that the specified object is NOT [Empty].
// //
// if a.NotEmpty(obj) { // if a.NotEmpty(obj) {
// assert.Equal(t, "two", obj[1]) // assert.Equal(t, "two", obj[1])
@ -1142,8 +1213,7 @@ func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) {
NotEmpty(a.t, object, msgAndArgs...) NotEmpty(a.t, object, msgAndArgs...)
} }
// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // NotEmptyf asserts that the specified object is NOT [Empty].
// a slice or a channel with len == 0.
// //
// if a.NotEmptyf(obj, "error message %s", "formatted") { // if a.NotEmptyf(obj, "error message %s", "formatted") {
// assert.Equal(t, "two", obj[1]) // assert.Equal(t, "two", obj[1])
@ -1201,7 +1271,25 @@ func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg str
NotEqualf(a.t, expected, actual, msg, args...) NotEqualf(a.t, expected, actual, msg, args...)
} }
// NotErrorIs asserts that at none of the errors in err's chain matches target. // NotErrorAs asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func (a *Assertions) NotErrorAs(err error, target interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotErrorAs(a.t, err, target, msgAndArgs...)
}
// NotErrorAsf asserts that none of the errors in err's chain matches target,
// but if so, sets target to that error value.
func (a *Assertions) NotErrorAsf(err error, target interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotErrorAsf(a.t, err, target, msg, args...)
}
// NotErrorIs asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is. // This is a wrapper for errors.Is.
func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) { func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
@ -1210,7 +1298,7 @@ func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface
NotErrorIs(a.t, err, target, msgAndArgs...) NotErrorIs(a.t, err, target, msgAndArgs...)
} }
// NotErrorIsf asserts that at none of the errors in err's chain matches target. // NotErrorIsf asserts that none of the errors in err's chain matches target.
// This is a wrapper for errors.Is. // This is a wrapper for errors.Is.
func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) { func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
@ -1327,12 +1415,15 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
NotSamef(a.t, expected, actual, msg, args...) NotSamef(a.t, expected, actual, msg, args...)
} }
// NotSubset asserts that the specified list(array, slice...) or map does NOT // NotSubset asserts that the list (array, slice, or map) does NOT contain all
// contain all elements given in the specified subset list(array, slice...) or // elements given in the subset (array, slice, or map).
// map. // Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// a.NotSubset([1, 3, 4], [1, 2]) // a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3}) // a.NotSubset({"x": 1, "y": 2}, {"z": 3})
// a.NotSubset([1, 3, 4], {1: "one", 2: "two"})
// a.NotSubset({"x": 1, "y": 2}, ["z"])
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) { func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1340,12 +1431,15 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
NotSubset(a.t, list, subset, msgAndArgs...) NotSubset(a.t, list, subset, msgAndArgs...)
} }
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT // NotSubsetf asserts that the list (array, slice, or map) does NOT contain all
// contain all elements given in the specified subset list(array, slice...) or // elements given in the subset (array, slice, or map).
// map. // Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted") // a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") // a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
// a.NotSubsetf([1, 3, 4], {1: "one", 2: "two"}, "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, ["z"], "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) { func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1505,11 +1599,15 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
Samef(a.t, expected, actual, msg, args...) Samef(a.t, expected, actual, msg, args...)
} }
// Subset asserts that the specified list(array, slice...) or map contains all // Subset asserts that the list (array, slice, or map) contains all elements
// elements given in the specified subset list(array, slice...) or map. // given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// a.Subset([1, 2, 3], [1, 2]) // a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1}) // a.Subset({"x": 1, "y": 2}, {"x": 1})
// a.Subset([1, 2, 3], {1: "one", 2: "two"})
// a.Subset({"x": 1, "y": 2}, ["x"])
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) { func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1517,11 +1615,15 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
Subset(a.t, list, subset, msgAndArgs...) Subset(a.t, list, subset, msgAndArgs...)
} }
// Subsetf asserts that the specified list(array, slice...) or map contains all // Subsetf asserts that the list (array, slice, or map) contains all elements
// elements given in the specified subset list(array, slice...) or map. // given in the subset (array, slice, or map).
// Map elements are key-value pairs unless compared with an array or slice where
// only the map key is evaluated.
// //
// a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted") // a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") // a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
// a.Subsetf([1, 2, 3], {1: "one", 2: "two"}, "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, ["x"], "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) { func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()

View file

@ -6,7 +6,7 @@ type TestingT interface {
FailNow() FailNow()
} }
type tHelper interface { type tHelper = interface {
Helper() Helper()
} }

4
vendor/golang.org/x/crypto/LICENSE generated vendored
View file

@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved. Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are modification, are permitted provided that the following conditions are
@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the in the documentation and/or other materials provided with the
distribution. distribution.
* Neither the name of Google Inc. nor the names of its * Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from contributors may be used to endorse or promote products derived from
this software without specific prior written permission. this software without specific prior written permission.

View file

@ -31,12 +31,11 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"encoding/base64" "encoding/base64"
"encoding/hex"
"encoding/json" "encoding/json"
"encoding/pem"
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
"net"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@ -353,6 +352,10 @@ func (c *Client) authorize(ctx context.Context, typ, val string) (*Authorization
if _, err := c.Discover(ctx); err != nil { if _, err := c.Discover(ctx); err != nil {
return nil, err return nil, err
} }
if c.dir.AuthzURL == "" {
// Pre-Authorization is unsupported
return nil, errPreAuthorizationNotSupported
}
type authzID struct { type authzID struct {
Type string `json:"type"` Type string `json:"type"`
@ -467,7 +470,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
// while waiting for a final authorization status. // while waiting for a final authorization status.
d := retryAfter(res.Header.Get("Retry-After")) d := retryAfter(res.Header.Get("Retry-After"))
if d == 0 { if d == 0 {
// Given that the fastest challenges TLS-SNI and HTTP-01 // Given that the fastest challenges TLS-ALPN and HTTP-01
// require a CA to make at least 1 network round trip // require a CA to make at least 1 network round trip
// and most likely persist a challenge state, // and most likely persist a challenge state,
// this default delay seems reasonable. // this default delay seems reasonable.
@ -514,7 +517,11 @@ func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error
return nil, err return nil, err
} }
res, err := c.post(ctx, nil, chal.URI, json.RawMessage("{}"), wantStatus( payload := json.RawMessage("{}")
if len(chal.Payload) != 0 {
payload = chal.Payload
}
res, err := c.post(ctx, nil, chal.URI, payload, wantStatus(
http.StatusOK, // according to the spec http.StatusOK, // according to the spec
http.StatusAccepted, // Let's Encrypt: see https://goo.gl/WsJ7VT (acme-divergences.md) http.StatusAccepted, // Let's Encrypt: see https://goo.gl/WsJ7VT (acme-divergences.md)
)) ))
@ -564,50 +571,28 @@ func (c *Client) HTTP01ChallengePath(token string) string {
} }
// TLSSNI01ChallengeCert creates a certificate for TLS-SNI-01 challenge response. // TLSSNI01ChallengeCert creates a certificate for TLS-SNI-01 challenge response.
// Always returns an error.
// //
// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec. // Deprecated: This challenge type was only present in pre-standardized ACME
func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) { // protocol drafts and is insecure for use in shared hosting environments.
ka, err := keyAuth(c.Key.Public(), token) func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (tls.Certificate, string, error) {
if err != nil { return tls.Certificate{}, "", errPreRFC
return tls.Certificate{}, "", err
}
b := sha256.Sum256([]byte(ka))
h := hex.EncodeToString(b[:])
name = fmt.Sprintf("%s.%s.acme.invalid", h[:32], h[32:])
cert, err = tlsChallengeCert([]string{name}, opt)
if err != nil {
return tls.Certificate{}, "", err
}
return cert, name, nil
} }
// TLSSNI02ChallengeCert creates a certificate for TLS-SNI-02 challenge response. // TLSSNI02ChallengeCert creates a certificate for TLS-SNI-02 challenge response.
// Always returns an error.
// //
// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec. // Deprecated: This challenge type was only present in pre-standardized ACME
func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) { // protocol drafts and is insecure for use in shared hosting environments.
b := sha256.Sum256([]byte(token)) func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (tls.Certificate, string, error) {
h := hex.EncodeToString(b[:]) return tls.Certificate{}, "", errPreRFC
sanA := fmt.Sprintf("%s.%s.token.acme.invalid", h[:32], h[32:])
ka, err := keyAuth(c.Key.Public(), token)
if err != nil {
return tls.Certificate{}, "", err
}
b = sha256.Sum256([]byte(ka))
h = hex.EncodeToString(b[:])
sanB := fmt.Sprintf("%s.%s.ka.acme.invalid", h[:32], h[32:])
cert, err = tlsChallengeCert([]string{sanA, sanB}, opt)
if err != nil {
return tls.Certificate{}, "", err
}
return cert, sanA, nil
} }
// TLSALPN01ChallengeCert creates a certificate for TLS-ALPN-01 challenge response. // TLSALPN01ChallengeCert creates a certificate for TLS-ALPN-01 challenge response.
// Servers can present the certificate to validate the challenge and prove control // Servers can present the certificate to validate the challenge and prove control
// over a domain name. For more details on TLS-ALPN-01 see // over an identifier (either a DNS name or the textual form of an IPv4 or IPv6
// https://tools.ietf.org/html/draft-shoemaker-acme-tls-alpn-00#section-3 // address). For more details on TLS-ALPN-01 see
// https://www.rfc-editor.org/rfc/rfc8737 and https://www.rfc-editor.org/rfc/rfc8738
// //
// The token argument is a Challenge.Token value. // The token argument is a Challenge.Token value.
// If a WithKey option is provided, its private part signs the returned cert, // If a WithKey option is provided, its private part signs the returned cert,
@ -615,9 +600,13 @@ func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tl
// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve. // If no WithKey option is provided, a new ECDSA key is generated using P-256 curve.
// //
// The returned certificate is valid for the next 24 hours and must be presented only when // The returned certificate is valid for the next 24 hours and must be presented only when
// the server name in the TLS ClientHello matches the domain, and the special acme-tls/1 ALPN protocol // the server name in the TLS ClientHello matches the identifier, and the special acme-tls/1 ALPN protocol
// has been specified. // has been specified.
func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption) (cert tls.Certificate, err error) { //
// Validation requests for IP address identifiers will use the reverse DNS form in the server name
// in the TLS ClientHello since the SNI extension is not supported for IP addresses.
// See RFC 8738 Section 6 for more information.
func (c *Client) TLSALPN01ChallengeCert(token, identifier string, opt ...CertOption) (cert tls.Certificate, err error) {
ka, err := keyAuth(c.Key.Public(), token) ka, err := keyAuth(c.Key.Public(), token)
if err != nil { if err != nil {
return tls.Certificate{}, err return tls.Certificate{}, err
@ -647,7 +636,7 @@ func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption)
} }
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, acmeExtension) tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, acmeExtension)
newOpt = append(newOpt, WithTemplate(tmpl)) newOpt = append(newOpt, WithTemplate(tmpl))
return tlsChallengeCert([]string{domain}, newOpt) return tlsChallengeCert(identifier, newOpt)
} }
// popNonce returns a nonce value previously stored with c.addNonce // popNonce returns a nonce value previously stored with c.addNonce
@ -701,7 +690,7 @@ func (c *Client) addNonce(h http.Header) {
} }
func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) { func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
r, err := http.NewRequest("HEAD", url, nil) r, err := http.NewRequestWithContext(ctx, "HEAD", url, nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -765,11 +754,15 @@ func defaultTLSChallengeCertTemplate() *x509.Certificate {
} }
} }
// tlsChallengeCert creates a temporary certificate for TLS-SNI challenges // tlsChallengeCert creates a temporary certificate for TLS-ALPN challenges
// with the given SANs and auto-generated public/private key pair. // for the given identifier, using an auto-generated public/private key pair.
// The Subject Common Name is set to the first SAN to aid debugging. //
// If the provided identifier is a domain name, it will be used as a DNS type SAN and for the
// subject common name. If the provided identifier is an IP address it will be used as an IP type
// SAN.
//
// To create a cert with a custom key pair, specify WithKey option. // To create a cert with a custom key pair, specify WithKey option.
func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) { func tlsChallengeCert(identifier string, opt []CertOption) (tls.Certificate, error) {
var key crypto.Signer var key crypto.Signer
tmpl := defaultTLSChallengeCertTemplate() tmpl := defaultTLSChallengeCertTemplate()
for _, o := range opt { for _, o := range opt {
@ -793,9 +786,12 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
return tls.Certificate{}, err return tls.Certificate{}, err
} }
} }
tmpl.DNSNames = san
if len(san) > 0 { if ip := net.ParseIP(identifier); ip != nil {
tmpl.Subject.CommonName = san[0] tmpl.IPAddresses = []net.IP{ip}
} else {
tmpl.DNSNames = []string{identifier}
tmpl.Subject.CommonName = identifier
} }
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
@ -808,11 +804,5 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
}, nil }, nil
} }
// encodePEM returns b encoded as PEM with block of type typ.
func encodePEM(typ string, b []byte) []byte {
pb := &pem.Block{Type: typ, Bytes: b}
return pem.EncodeToMemory(pb)
}
// timeNow is time.Now, except in tests which can mess with it. // timeNow is time.Now, except in tests which can mess with it.
var timeNow = time.Now var timeNow = time.Now

View file

@ -134,7 +134,8 @@ type Manager struct {
// RenewBefore optionally specifies how early certificates should // RenewBefore optionally specifies how early certificates should
// be renewed before they expire. // be renewed before they expire.
// //
// If zero, they're renewed 30 days before expiration. // If zero, they're renewed at the lesser of 30 days or
// 1/3 of the certificate lifetime.
RenewBefore time.Duration RenewBefore time.Duration
// Client is used to perform low-level operations, such as account registration // Client is used to perform low-level operations, such as account registration
@ -292,6 +293,10 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
} }
// regular domain // regular domain
if err := m.hostPolicy()(ctx, name); err != nil {
return nil, err
}
ck := certKey{ ck := certKey{
domain: strings.TrimSuffix(name, "."), // golang.org/issue/18114 domain: strings.TrimSuffix(name, "."), // golang.org/issue/18114
isRSA: !supportsECDSA(hello), isRSA: !supportsECDSA(hello),
@ -305,9 +310,6 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
} }
// first-time // first-time
if err := m.hostPolicy()(ctx, name); err != nil {
return nil, err
}
cert, err = m.createCert(ctx, ck) cert, err = m.createCert(ctx, ck)
if err != nil { if err != nil {
return nil, err return nil, err
@ -463,7 +465,7 @@ func (m *Manager) cert(ctx context.Context, ck certKey) (*tls.Certificate, error
leaf: cert.Leaf, leaf: cert.Leaf,
} }
m.state[ck] = s m.state[ck] = s
m.startRenew(ck, s.key, s.leaf.NotAfter) m.startRenew(ck, s.key, s.leaf.NotBefore, s.leaf.NotAfter)
return cert, nil return cert, nil
} }
@ -609,7 +611,7 @@ func (m *Manager) createCert(ctx context.Context, ck certKey) (*tls.Certificate,
} }
state.cert = der state.cert = der
state.leaf = leaf state.leaf = leaf
m.startRenew(ck, state.key, state.leaf.NotAfter) m.startRenew(ck, state.key, state.leaf.NotBefore, state.leaf.NotAfter)
return state.tlscert() return state.tlscert()
} }
@ -907,7 +909,7 @@ func httpTokenCacheKey(tokenPath string) string {
// //
// The key argument is a certificate private key. // The key argument is a certificate private key.
// The exp argument is the cert expiration time (NotAfter). // The exp argument is the cert expiration time (NotAfter).
func (m *Manager) startRenew(ck certKey, key crypto.Signer, exp time.Time) { func (m *Manager) startRenew(ck certKey, key crypto.Signer, notBefore, notAfter time.Time) {
m.renewalMu.Lock() m.renewalMu.Lock()
defer m.renewalMu.Unlock() defer m.renewalMu.Unlock()
if m.renewal[ck] != nil { if m.renewal[ck] != nil {
@ -919,7 +921,7 @@ func (m *Manager) startRenew(ck certKey, key crypto.Signer, exp time.Time) {
} }
dr := &domainRenewal{m: m, ck: ck, key: key} dr := &domainRenewal{m: m, ck: ck, key: key}
m.renewal[ck] = dr m.renewal[ck] = dr
dr.start(exp) dr.start(notBefore, notAfter)
} }
// stopRenew stops all currently running cert renewal timers. // stopRenew stops all currently running cert renewal timers.
@ -1027,13 +1029,6 @@ func (m *Manager) hostPolicy() HostPolicy {
return defaultHostPolicy return defaultHostPolicy
} }
func (m *Manager) renewBefore() time.Duration {
if m.RenewBefore > renewJitter {
return m.RenewBefore
}
return 720 * time.Hour // 30 days
}
func (m *Manager) now() time.Time { func (m *Manager) now() time.Time {
if m.nowFunc != nil { if m.nowFunc != nil {
return m.nowFunc() return m.nowFunc()

View file

@ -10,7 +10,6 @@ import (
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"time" "time"
) )
@ -124,32 +123,13 @@ func (ln *listener) Close() error {
return ln.tcpListener.Close() return ln.tcpListener.Close()
} }
func homeDir() string {
if runtime.GOOS == "windows" {
return os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
}
if h := os.Getenv("HOME"); h != "" {
return h
}
return "/"
}
func cacheDir() string { func cacheDir() string {
const base = "golang-autocert" const base = "golang-autocert"
switch runtime.GOOS { cache, err := os.UserCacheDir()
case "darwin": if err != nil {
return filepath.Join(homeDir(), "Library", "Caches", base) // Fall back to the root directory.
case "windows": cache = "/.cache"
for _, ev := range []string{"APPDATA", "CSIDL_APPDATA", "TEMP", "TMP"} {
if v := os.Getenv(ev); v != "" {
return filepath.Join(v, base)
}
}
// Worst case:
return filepath.Join(homeDir(), base)
} }
if xdg := os.Getenv("XDG_CACHE_HOME"); xdg != "" {
return filepath.Join(xdg, base) return filepath.Join(cache, base)
}
return filepath.Join(homeDir(), ".cache", base)
} }

View file

@ -11,9 +11,6 @@ import (
"time" "time"
) )
// renewJitter is the maximum deviation from Manager.RenewBefore.
const renewJitter = time.Hour
// domainRenewal tracks the state used by the periodic timers // domainRenewal tracks the state used by the periodic timers
// renewing a single domain's cert. // renewing a single domain's cert.
type domainRenewal struct { type domainRenewal struct {
@ -30,13 +27,13 @@ type domainRenewal struct {
// defined by the certificate expiration time exp. // defined by the certificate expiration time exp.
// //
// If the timer is already started, calling start is a noop. // If the timer is already started, calling start is a noop.
func (dr *domainRenewal) start(exp time.Time) { func (dr *domainRenewal) start(notBefore, notAfter time.Time) {
dr.timerMu.Lock() dr.timerMu.Lock()
defer dr.timerMu.Unlock() defer dr.timerMu.Unlock()
if dr.timer != nil { if dr.timer != nil {
return return
} }
dr.timer = time.AfterFunc(dr.next(exp), dr.renew) dr.timer = time.AfterFunc(dr.next(notBefore, notAfter), dr.renew)
} }
// stop stops the cert renewal timer and waits for any in-flight calls to renew // stop stops the cert renewal timer and waits for any in-flight calls to renew
@ -79,7 +76,7 @@ func (dr *domainRenewal) renew() {
// TODO: rotate dr.key at some point? // TODO: rotate dr.key at some point?
next, err := dr.do(ctx) next, err := dr.do(ctx)
if err != nil { if err != nil {
next = renewJitter / 2 next = time.Hour / 2
next += time.Duration(pseudoRand.int63n(int64(next))) next += time.Duration(pseudoRand.int63n(int64(next)))
} }
testDidRenewLoop(next, err) testDidRenewLoop(next, err)
@ -107,8 +104,8 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
// a race is likely unavoidable in a distributed environment // a race is likely unavoidable in a distributed environment
// but we try nonetheless // but we try nonetheless
if tlscert, err := dr.m.cacheGet(ctx, dr.ck); err == nil { if tlscert, err := dr.m.cacheGet(ctx, dr.ck); err == nil {
next := dr.next(tlscert.Leaf.NotAfter) next := dr.next(tlscert.Leaf.NotBefore, tlscert.Leaf.NotAfter)
if next > dr.m.renewBefore()+renewJitter { if next > 0 {
signer, ok := tlscert.PrivateKey.(crypto.Signer) signer, ok := tlscert.PrivateKey.(crypto.Signer)
if ok { if ok {
state := &certState{ state := &certState{
@ -139,18 +136,23 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
return 0, err return 0, err
} }
dr.updateState(state) dr.updateState(state)
return dr.next(leaf.NotAfter), nil return dr.next(leaf.NotBefore, leaf.NotAfter), nil
} }
func (dr *domainRenewal) next(expiry time.Time) time.Duration { // next returns the wait time before the next renewal should start.
d := expiry.Sub(dr.m.now()) - dr.m.renewBefore() // If manager.RenewBefore is set, it uses that capped at 30 days,
// add a bit of randomness to renew deadline // otherwise it uses a default of 1/3 of the cert lifetime.
n := pseudoRand.int63n(int64(renewJitter)) // It builds in a jitter of 10% of the renew threshold, capped at 1 hour.
d -= time.Duration(n) func (dr *domainRenewal) next(notBefore, notAfter time.Time) time.Duration {
if d < 0 { threshold := min(notAfter.Sub(notBefore)/3, 30*24*time.Hour)
return 0 if dr.m.RenewBefore > 0 {
threshold = min(dr.m.RenewBefore, 30*24*time.Hour)
} }
return d maxJitter := min(threshold/10, time.Hour)
jitter := pseudoRand.int63n(int64(maxJitter))
renewAt := notAfter.Add(-(threshold - time.Duration(jitter)))
renewWait := renewAt.Sub(dr.m.now())
return max(0, renewWait)
} }
var testDidRenewLoop = func(next time.Duration, err error) {} var testDidRenewLoop = func(next time.Duration, err error) {}

View file

@ -15,6 +15,7 @@ import (
"io" "io"
"math/big" "math/big"
"net/http" "net/http"
"runtime/debug"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -65,7 +66,7 @@ func (c *Client) retryTimer() *retryTimer {
// The n argument is always bounded between 1 and 30. // The n argument is always bounded between 1 and 30.
// The returned value is always greater than 0. // The returned value is always greater than 0.
func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration { func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration {
const max = 10 * time.Second const maxVal = 10 * time.Second
var jitter time.Duration var jitter time.Duration
if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil { if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil {
// Set the minimum to 1ms to avoid a case where // Set the minimum to 1ms to avoid a case where
@ -85,10 +86,7 @@ func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration {
n = 30 n = 30
} }
d := time.Duration(1<<uint(n-1))*time.Second + jitter d := time.Duration(1<<uint(n-1))*time.Second + jitter
if d > max { return min(d, maxVal)
return max
}
return d
} }
// retryAfter parses a Retry-After HTTP header value, // retryAfter parses a Retry-After HTTP header value,
@ -130,7 +128,7 @@ func wantStatus(codes ...int) resOkay {
func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) { func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) {
retry := c.retryTimer() retry := c.retryTimer()
for { for {
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -230,7 +228,7 @@ func (c *Client) postNoRetry(ctx context.Context, key crypto.Signer, url string,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
req, err := http.NewRequest("POST", url, bytes.NewReader(b)) req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(b))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -271,9 +269,27 @@ func (c *Client) httpClient() *http.Client {
} }
// packageVersion is the version of the module that contains this package, for // packageVersion is the version of the module that contains this package, for
// sending as part of the User-Agent header. It's set in version_go112.go. // sending as part of the User-Agent header.
var packageVersion string var packageVersion string
func init() {
// Set packageVersion if the binary was built in modules mode and x/crypto
// was not replaced with a different module.
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
for _, m := range info.Deps {
if m.Path != "golang.org/x/crypto" {
continue
}
if m.Replace == nil {
packageVersion = m.Version
}
break
}
}
// userAgent returns the User-Agent header value. It includes the package name, // userAgent returns the User-Agent header value. It includes the package name,
// the module version (if available), and the c.UserAgent value (if set). // the module version (if available), and the c.UserAgent value (if set).
func (c *Client) userAgent() string { func (c *Client) userAgent() string {

View file

@ -92,7 +92,7 @@ func jwsEncodeJSON(claimset interface{}, key crypto.Signer, kid KeyID, nonce, ur
if err != nil { if err != nil {
return nil, err return nil, err
} }
phead := base64.RawURLEncoding.EncodeToString([]byte(phJSON)) phead := base64.RawURLEncoding.EncodeToString(phJSON)
var payload string var payload string
if val, ok := claimset.(string); ok { if val, ok := claimset.(string); ok {
payload = val payload = val

View file

@ -232,7 +232,7 @@ func (c *Client) AuthorizeOrder(ctx context.Context, id []AuthzID, opt ...OrderO
return responseOrder(res) return responseOrder(res)
} }
// GetOrder retrives an order identified by the given URL. // GetOrder retrieves an order identified by the given URL.
// For orders created with AuthorizeOrder, the url value is Order.URI. // For orders created with AuthorizeOrder, the url value is Order.URI.
// //
// If a caller needs to poll an order until its status is final, // If a caller needs to poll an order until its status is final,
@ -272,7 +272,7 @@ func (c *Client) WaitOrder(ctx context.Context, url string) (*Order, error) {
case err != nil: case err != nil:
// Skip and retry. // Skip and retry.
case o.Status == StatusInvalid: case o.Status == StatusInvalid:
return nil, &OrderError{OrderURL: o.URI, Status: o.Status} return nil, &OrderError{OrderURL: o.URI, Status: o.Status, Problem: o.Error}
case o.Status == StatusReady || o.Status == StatusValid: case o.Status == StatusReady || o.Status == StatusValid:
return o, nil return o, nil
} }
@ -369,7 +369,7 @@ func (c *Client) CreateOrderCert(ctx context.Context, url string, csr []byte, bu
} }
// The only acceptable status post finalize and WaitOrder is "valid". // The only acceptable status post finalize and WaitOrder is "valid".
if o.Status != StatusValid { if o.Status != StatusValid {
return nil, "", &OrderError{OrderURL: o.URI, Status: o.Status} return nil, "", &OrderError{OrderURL: o.URI, Status: o.Status, Problem: o.Error}
} }
crt, err := c.fetchCertRFC(ctx, o.CertURL, bundle) crt, err := c.fetchCertRFC(ctx, o.CertURL, bundle)
return crt, o.CertURL, err return crt, o.CertURL, err

View file

@ -7,6 +7,7 @@ package acme
import ( import (
"crypto" "crypto"
"crypto/x509" "crypto/x509"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -55,6 +56,10 @@ var (
// ErrNoAccount indicates that the Client's key has not been registered with the CA. // ErrNoAccount indicates that the Client's key has not been registered with the CA.
ErrNoAccount = errors.New("acme: account does not exist") ErrNoAccount = errors.New("acme: account does not exist")
// errPreAuthorizationNotSupported indicates that the server does not
// support pre-authorization of identifiers.
errPreAuthorizationNotSupported = errors.New("acme: pre-authorization is not supported")
) )
// A Subproblem describes an ACME subproblem as reported in an Error. // A Subproblem describes an ACME subproblem as reported in an Error.
@ -149,13 +154,16 @@ func (a *AuthorizationError) Error() string {
// OrderError is returned from Client's order related methods. // OrderError is returned from Client's order related methods.
// It indicates the order is unusable and the clients should start over with // It indicates the order is unusable and the clients should start over with
// AuthorizeOrder. // AuthorizeOrder. A Problem description may be provided with details on
// what caused the order to become unusable.
// //
// The clients can still fetch the order object from CA using GetOrder // The clients can still fetch the order object from CA using GetOrder
// to inspect its state. // to inspect its state.
type OrderError struct { type OrderError struct {
OrderURL string OrderURL string
Status string Status string
// Problem is the error that occurred while processing the order.
Problem *Error
} }
func (oe *OrderError) Error() string { func (oe *OrderError) Error() string {
@ -288,7 +296,7 @@ type Directory struct {
// KeyChangeURL allows to perform account key rollover flow. // KeyChangeURL allows to perform account key rollover flow.
KeyChangeURL string KeyChangeURL string
// Term is a URI identifying the current terms of service. // Terms is a URI identifying the current terms of service.
Terms string Terms string
// Website is an HTTP or HTTPS URL locating a website // Website is an HTTP or HTTPS URL locating a website
@ -527,6 +535,16 @@ type Challenge struct {
// when this challenge was used. // when this challenge was used.
// The type of a non-nil value is *Error. // The type of a non-nil value is *Error.
Error error Error error
// Payload is the JSON-formatted payload that the client sends
// to the server to indicate it is ready to respond to the challenge.
// When unset, it defaults to an empty JSON object: {}.
// For most challenges, the client must not set Payload,
// see https://tools.ietf.org/html/rfc8555#section-7.5.1.
// Payload is used only for newer challenges (such as "device-attest-01")
// where the client must send additional data for the server to validate
// the challenge.
Payload json.RawMessage
} }
// wireChallenge is ACME JSON challenge representation. // wireChallenge is ACME JSON challenge representation.
@ -604,7 +622,7 @@ func (*certOptKey) privateCertOpt() {}
// //
// In TLS ChallengeCert methods, the template is also used as parent, // In TLS ChallengeCert methods, the template is also used as parent,
// resulting in a self-signed certificate. // resulting in a self-signed certificate.
// The DNSNames field of t is always overwritten for tls-sni challenge certs. // The DNSNames or IPAddresses fields of t are always overwritten for tls-alpn challenge certs.
func WithTemplate(t *x509.Certificate) CertOption { func WithTemplate(t *x509.Certificate) CertOption {
return (*certOptTemplate)(t) return (*certOptTemplate)(t)
} }

View file

@ -1,27 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.12
package acme
import "runtime/debug"
func init() {
// Set packageVersion if the binary was built in modules mode and x/crypto
// was not replaced with a different module.
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
for _, m := range info.Deps {
if m.Path != "golang.org/x/crypto" {
continue
}
if m.Replace == nil {
packageVersion = m.Version
}
break
}
}

View file

@ -6,7 +6,7 @@
// Argon2 was selected as the winner of the Password Hashing Competition and can // Argon2 was selected as the winner of the Password Hashing Competition and can
// be used to derive cryptographic keys from passwords. // be used to derive cryptographic keys from passwords.
// //
// For a detailed specification of Argon2 see [1]. // For a detailed specification of Argon2 see [argon2-specs.pdf].
// //
// If you aren't sure which function you need, use Argon2id (IDKey) and // If you aren't sure which function you need, use Argon2id (IDKey) and
// the parameter recommendations for your scenario. // the parameter recommendations for your scenario.
@ -17,7 +17,7 @@
// It uses data-independent memory access, which is preferred for password // It uses data-independent memory access, which is preferred for password
// hashing and password-based key derivation. Argon2i requires more passes over // hashing and password-based key derivation. Argon2i requires more passes over
// memory than Argon2id to protect from trade-off attacks. The recommended // memory than Argon2id to protect from trade-off attacks. The recommended
// parameters (taken from [2]) for non-interactive operations are time=3 and to // parameters (taken from [RFC 9106 Section 7.3]) for non-interactive operations are time=3 and to
// use the maximum available memory. // use the maximum available memory.
// //
// # Argon2id // # Argon2id
@ -27,11 +27,11 @@
// half of the first iteration over the memory and data-dependent memory access // half of the first iteration over the memory and data-dependent memory access
// for the rest. Argon2id is side-channel resistant and provides better brute- // for the rest. Argon2id is side-channel resistant and provides better brute-
// force cost savings due to time-memory tradeoffs than Argon2i. The recommended // force cost savings due to time-memory tradeoffs than Argon2i. The recommended
// parameters for non-interactive operations (taken from [2]) are time=1 and to // parameters for non-interactive operations (taken from [RFC 9106 Section 7.3]) are time=1 and to
// use the maximum available memory. // use the maximum available memory.
// //
// [1] https://github.com/P-H-C/phc-winner-argon2/blob/master/argon2-specs.pdf // [argon2-specs.pdf]: https://github.com/P-H-C/phc-winner-argon2/blob/master/argon2-specs.pdf
// [2] https://tools.ietf.org/html/draft-irtf-cfrg-argon2-03#section-9.3 // [RFC 9106 Section 7.3]: https://www.rfc-editor.org/rfc/rfc9106.html#section-7.3
package argon2 package argon2
import ( import (
@ -59,7 +59,7 @@ const (
// //
// key := argon2.Key([]byte("some password"), salt, 3, 32*1024, 4, 32) // key := argon2.Key([]byte("some password"), salt, 3, 32*1024, 4, 32)
// //
// The draft RFC recommends[2] time=3, and memory=32*1024 is a sensible number. // [RFC 9106 Section 7.3] recommends time=3, and memory=32*1024 as a sensible number.
// If using that amount of memory (32 MB) is not possible in some contexts then // If using that amount of memory (32 MB) is not possible in some contexts then
// the time parameter can be increased to compensate. // the time parameter can be increased to compensate.
// //
@ -69,6 +69,8 @@ const (
// adjusted to the number of available CPUs. The cost parameters should be // adjusted to the number of available CPUs. The cost parameters should be
// increased as memory latency and CPU parallelism increases. Remember to get a // increased as memory latency and CPU parallelism increases. Remember to get a
// good random salt. // good random salt.
//
// [RFC 9106 Section 7.3]: https://www.rfc-editor.org/rfc/rfc9106.html#section-7.3
func Key(password, salt []byte, time, memory uint32, threads uint8, keyLen uint32) []byte { func Key(password, salt []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
return deriveKey(argon2i, password, salt, nil, nil, time, memory, threads, keyLen) return deriveKey(argon2i, password, salt, nil, nil, time, memory, threads, keyLen)
} }
@ -83,7 +85,7 @@ func Key(password, salt []byte, time, memory uint32, threads uint8, keyLen uint3
// //
// key := argon2.IDKey([]byte("some password"), salt, 1, 64*1024, 4, 32) // key := argon2.IDKey([]byte("some password"), salt, 1, 64*1024, 4, 32)
// //
// The draft RFC recommends[2] time=1, and memory=64*1024 is a sensible number. // [RFC 9106 Section 7.3] recommends time=1, and memory=64*1024 as a sensible number.
// If using that amount of memory (64 MB) is not possible in some contexts then // If using that amount of memory (64 MB) is not possible in some contexts then
// the time parameter can be increased to compensate. // the time parameter can be increased to compensate.
// //
@ -93,6 +95,8 @@ func Key(password, salt []byte, time, memory uint32, threads uint8, keyLen uint3
// adjusted to the numbers of available CPUs. The cost parameters should be // adjusted to the numbers of available CPUs. The cost parameters should be
// increased as memory latency and CPU parallelism increases. Remember to get a // increased as memory latency and CPU parallelism increases. Remember to get a
// good random salt. // good random salt.
//
// [RFC 9106 Section 7.3]: https://www.rfc-editor.org/rfc/rfc9106.html#section-7.3
func IDKey(password, salt []byte, time, memory uint32, threads uint8, keyLen uint32) []byte { func IDKey(password, salt []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
return deriveKey(argon2id, password, salt, nil, nil, time, memory, threads, keyLen) return deriveKey(argon2id, password, salt, nil, nil, time, memory, threads, keyLen)
} }

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -12,6 +12,8 @@ import (
// XOF defines the interface to hash functions that // XOF defines the interface to hash functions that
// support arbitrary-length output. // support arbitrary-length output.
//
// New callers should prefer the standard library [hash.XOF].
type XOF interface { type XOF interface {
// Write absorbs more data into the hash's state. It panics if called // Write absorbs more data into the hash's state. It panics if called
// after Read. // after Read.
@ -47,6 +49,8 @@ const maxOutputLength = (1 << 32) * 64
// //
// A non-nil key turns the hash into a MAC. The key must between // A non-nil key turns the hash into a MAC. The key must between
// zero and 32 bytes long. // zero and 32 bytes long.
//
// The result can be safely interface-upgraded to [hash.XOF].
func NewXOF(size uint32, key []byte) (XOF, error) { func NewXOF(size uint32, key []byte) (XOF, error) {
if len(key) > Size { if len(key) > Size {
return nil, errKeySize return nil, errKeySize
@ -93,6 +97,10 @@ func (x *xof) Clone() XOF {
return &clone return &clone
} }
func (x *xof) BlockSize() int {
return x.d.BlockSize()
}
func (x *xof) Reset() { func (x *xof) Reset() {
x.cfg[0] = byte(Size) x.cfg[0] = byte(Size)
binary.LittleEndian.PutUint32(x.cfg[4:], uint32(Size)) // leaf length binary.LittleEndian.PutUint32(x.cfg[4:], uint32(Size)) // leaf length

11
vendor/golang.org/x/crypto/blake2b/go125.go generated vendored Normal file
View file

@ -0,0 +1,11 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.25
package blake2b
import "hash"
var _ hash.XOF = (*xof)(nil)

View file

@ -234,7 +234,7 @@ func (b *Builder) AddASN1(tag asn1.Tag, f BuilderContinuation) {
// Identifiers with the low five bits set indicate high-tag-number format // Identifiers with the low five bits set indicate high-tag-number format
// (two or more octets), which we don't support. // (two or more octets), which we don't support.
if tag&0x1f == 0x1f { if tag&0x1f == 0x1f {
b.err = fmt.Errorf("cryptobyte: high-tag number identifier octects not supported: 0x%x", tag) b.err = fmt.Errorf("cryptobyte: high-tag number identifier octets not supported: 0x%x", tag)
return return
} }
b.AddUint8(uint8(tag)) b.AddUint8(uint8(tag))

View file

@ -4,7 +4,7 @@
// Package asn1 contains supporting types for parsing and building ASN.1 // Package asn1 contains supporting types for parsing and building ASN.1
// messages with the cryptobyte package. // messages with the cryptobyte package.
package asn1 // import "golang.org/x/crypto/cryptobyte/asn1" package asn1
// Tag represents an ASN.1 identifier octet, consisting of a tag number // Tag represents an ASN.1 identifier octet, consisting of a tag number
// (indicating a type) and class (such as context-specific or constructed). // (indicating a type) and class (such as context-specific or constructed).

View file

@ -15,7 +15,7 @@
// //
// See the documentation and examples for the Builder and String types to get // See the documentation and examples for the Builder and String types to get
// started. // started.
package cryptobyte // import "golang.org/x/crypto/cryptobyte" package cryptobyte
// String represents a string of bytes. It provides methods for parsing // String represents a string of bytes. It provides methods for parsing
// fixed-length and length-prefixed values from it. // fixed-length and length-prefixed values from it.

View file

@ -5,7 +5,7 @@
// Package ocsp parses OCSP responses as specified in RFC 2560. OCSP responses // Package ocsp parses OCSP responses as specified in RFC 2560. OCSP responses
// are signed messages attesting to the validity of a certificate for a small // are signed messages attesting to the validity of a certificate for a small
// period of time. This is used to manage revocation for X.509 certificates. // period of time. This is used to manage revocation for X.509 certificates.
package ocsp // import "golang.org/x/crypto/ocsp" package ocsp
import ( import (
"crypto" "crypto"

View file

@ -16,7 +16,7 @@ Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To
choose, you can pass the `New` functions from the different SHA packages to choose, you can pass the `New` functions from the different SHA packages to
pbkdf2.Key. pbkdf2.Key.
*/ */
package pbkdf2 // import "golang.org/x/crypto/pbkdf2" package pbkdf2
import ( import (
"crypto/hmac" "crypto/hmac"

View file

@ -1,62 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha3 implements the SHA-3 fixed-output-length hash functions and
// the SHAKE variable-output-length hash functions defined by FIPS-202.
//
// Both types of hash function use the "sponge" construction and the Keccak
// permutation. For a detailed specification see http://keccak.noekeon.org/
//
// # Guidance
//
// If you aren't sure what function you need, use SHAKE256 with at least 64
// bytes of output. The SHAKE instances are faster than the SHA3 instances;
// the latter have to allocate memory to conform to the hash.Hash interface.
//
// If you need a secret-key MAC (message authentication code), prepend the
// secret key to the input, hash with SHAKE256 and read at least 32 bytes of
// output.
//
// # Security strengths
//
// The SHA3-x (x equals 224, 256, 384, or 512) functions have a security
// strength against preimage attacks of x bits. Since they only produce "x"
// bits of output, their collision-resistance is only "x/2" bits.
//
// The SHAKE-256 and -128 functions have a generic security strength of 256 and
// 128 bits against all attacks, provided that at least 2x bits of their output
// is used. Requesting more than 64 or 32 bytes of output, respectively, does
// not increase the collision-resistance of the SHAKE functions.
//
// # The sponge construction
//
// A sponge builds a pseudo-random function from a public pseudo-random
// permutation, by applying the permutation to a state of "rate + capacity"
// bytes, but hiding "capacity" of the bytes.
//
// A sponge starts out with a zero state. To hash an input using a sponge, up
// to "rate" bytes of the input are XORed into the sponge's state. The sponge
// is then "full" and the permutation is applied to "empty" it. This process is
// repeated until all the input has been "absorbed". The input is then padded.
// The digest is "squeezed" from the sponge in the same way, except that output
// is copied out instead of input being XORed in.
//
// A sponge is parameterized by its generic security strength, which is equal
// to half its capacity; capacity + rate is equal to the permutation's width.
// Since the KeccakF-1600 permutation is 1600 bits (200 bytes) wide, this means
// that the security strength of a sponge instance is equal to (1600 - bitrate) / 2.
//
// # Recommendations
//
// The SHAKE functions are recommended for most new uses. They can produce
// output of arbitrary length. SHAKE256, with an output length of at least
// 64 bytes, provides 256-bit security against all attacks. The Keccak team
// recommends it for most applications upgrading from SHA2-512. (NIST chose a
// much stronger, but much slower, sponge instance for SHA3-512.)
//
// The SHA-3 functions are "drop-in" replacements for the SHA-2 functions.
// They produce output of the same length, with the same security strengths
// against all attacks. This means, in particular, that SHA3-256 only has
// 128-bit collision resistance, because its output length is 32 bytes.
package sha3 // import "golang.org/x/crypto/sha3"

View file

@ -2,96 +2,94 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package sha3 implements the SHA-3 hash algorithms and the SHAKE extendable
// output functions defined in FIPS 202.
//
// Most of this package is a wrapper around the crypto/sha3 package in the
// standard library. The only exception is the legacy Keccak hash functions.
package sha3 package sha3
// This file provides functions for creating instances of the SHA-3
// and SHAKE hash functions, as well as utility functions for hashing
// bytes.
import ( import (
"crypto/sha3"
"hash" "hash"
) )
// New224 creates a new SHA3-224 hash. // New224 creates a new SHA3-224 hash.
// Its generic security strength is 224 bits against preimage attacks, // Its generic security strength is 224 bits against preimage attacks,
// and 112 bits against collision attacks. // and 112 bits against collision attacks.
//
// It is a wrapper for the [sha3.New224] function in the standard library.
//
//go:fix inline
func New224() hash.Hash { func New224() hash.Hash {
if h := new224Asm(); h != nil { return sha3.New224()
return h
}
return &state{rate: 144, outputLen: 28, dsbyte: 0x06}
} }
// New256 creates a new SHA3-256 hash. // New256 creates a new SHA3-256 hash.
// Its generic security strength is 256 bits against preimage attacks, // Its generic security strength is 256 bits against preimage attacks,
// and 128 bits against collision attacks. // and 128 bits against collision attacks.
//
// It is a wrapper for the [sha3.New256] function in the standard library.
//
//go:fix inline
func New256() hash.Hash { func New256() hash.Hash {
if h := new256Asm(); h != nil { return sha3.New256()
return h
}
return &state{rate: 136, outputLen: 32, dsbyte: 0x06}
} }
// New384 creates a new SHA3-384 hash. // New384 creates a new SHA3-384 hash.
// Its generic security strength is 384 bits against preimage attacks, // Its generic security strength is 384 bits against preimage attacks,
// and 192 bits against collision attacks. // and 192 bits against collision attacks.
//
// It is a wrapper for the [sha3.New384] function in the standard library.
//
//go:fix inline
func New384() hash.Hash { func New384() hash.Hash {
if h := new384Asm(); h != nil { return sha3.New384()
return h
}
return &state{rate: 104, outputLen: 48, dsbyte: 0x06}
} }
// New512 creates a new SHA3-512 hash. // New512 creates a new SHA3-512 hash.
// Its generic security strength is 512 bits against preimage attacks, // Its generic security strength is 512 bits against preimage attacks,
// and 256 bits against collision attacks. // and 256 bits against collision attacks.
//
// It is a wrapper for the [sha3.New512] function in the standard library.
//
//go:fix inline
func New512() hash.Hash { func New512() hash.Hash {
if h := new512Asm(); h != nil { return sha3.New512()
return h
}
return &state{rate: 72, outputLen: 64, dsbyte: 0x06}
} }
// NewLegacyKeccak256 creates a new Keccak-256 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New256 instead.
func NewLegacyKeccak256() hash.Hash { return &state{rate: 136, outputLen: 32, dsbyte: 0x01} }
// NewLegacyKeccak512 creates a new Keccak-512 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New512 instead.
func NewLegacyKeccak512() hash.Hash { return &state{rate: 72, outputLen: 64, dsbyte: 0x01} }
// Sum224 returns the SHA3-224 digest of the data. // Sum224 returns the SHA3-224 digest of the data.
func Sum224(data []byte) (digest [28]byte) { //
h := New224() // It is a wrapper for the [sha3.Sum224] function in the standard library.
h.Write(data) //
h.Sum(digest[:0]) //go:fix inline
return func Sum224(data []byte) [28]byte {
return sha3.Sum224(data)
} }
// Sum256 returns the SHA3-256 digest of the data. // Sum256 returns the SHA3-256 digest of the data.
func Sum256(data []byte) (digest [32]byte) { //
h := New256() // It is a wrapper for the [sha3.Sum256] function in the standard library.
h.Write(data) //
h.Sum(digest[:0]) //go:fix inline
return func Sum256(data []byte) [32]byte {
return sha3.Sum256(data)
} }
// Sum384 returns the SHA3-384 digest of the data. // Sum384 returns the SHA3-384 digest of the data.
func Sum384(data []byte) (digest [48]byte) { //
h := New384() // It is a wrapper for the [sha3.Sum384] function in the standard library.
h.Write(data) //
h.Sum(digest[:0]) //go:fix inline
return func Sum384(data []byte) [48]byte {
return sha3.Sum384(data)
} }
// Sum512 returns the SHA3-512 digest of the data. // Sum512 returns the SHA3-512 digest of the data.
func Sum512(data []byte) (digest [64]byte) { //
h := New512() // It is a wrapper for the [sha3.Sum512] function in the standard library.
h.Write(data) //
h.Sum(digest[:0]) //go:fix inline
return func Sum512(data []byte) [64]byte {
return sha3.Sum512(data)
} }

View file

@ -1,27 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !gc || purego || !s390x
package sha3
import (
"hash"
)
// new224Asm returns an assembly implementation of SHA3-224 if available,
// otherwise it returns nil.
func new224Asm() hash.Hash { return nil }
// new256Asm returns an assembly implementation of SHA3-256 if available,
// otherwise it returns nil.
func new256Asm() hash.Hash { return nil }
// new384Asm returns an assembly implementation of SHA3-384 if available,
// otherwise it returns nil.
func new384Asm() hash.Hash { return nil }
// new512Asm returns an assembly implementation of SHA3-512 if available,
// otherwise it returns nil.
func new512Asm() hash.Hash { return nil }

View file

@ -1,13 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 && !purego && gc
package sha3
// This function is implemented in keccakf_amd64.s.
//go:noescape
func keccakF1600(a *[25]uint64)

View file

@ -1,390 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 && !purego && gc
// This code was translated into a form compatible with 6a from the public
// domain sources at https://github.com/gvanas/KeccakCodePackage
// Offsets in state
#define _ba (0*8)
#define _be (1*8)
#define _bi (2*8)
#define _bo (3*8)
#define _bu (4*8)
#define _ga (5*8)
#define _ge (6*8)
#define _gi (7*8)
#define _go (8*8)
#define _gu (9*8)
#define _ka (10*8)
#define _ke (11*8)
#define _ki (12*8)
#define _ko (13*8)
#define _ku (14*8)
#define _ma (15*8)
#define _me (16*8)
#define _mi (17*8)
#define _mo (18*8)
#define _mu (19*8)
#define _sa (20*8)
#define _se (21*8)
#define _si (22*8)
#define _so (23*8)
#define _su (24*8)
// Temporary registers
#define rT1 AX
// Round vars
#define rpState DI
#define rpStack SP
#define rDa BX
#define rDe CX
#define rDi DX
#define rDo R8
#define rDu R9
#define rBa R10
#define rBe R11
#define rBi R12
#define rBo R13
#define rBu R14
#define rCa SI
#define rCe BP
#define rCi rBi
#define rCo rBo
#define rCu R15
#define MOVQ_RBI_RCE MOVQ rBi, rCe
#define XORQ_RT1_RCA XORQ rT1, rCa
#define XORQ_RT1_RCE XORQ rT1, rCe
#define XORQ_RBA_RCU XORQ rBa, rCu
#define XORQ_RBE_RCU XORQ rBe, rCu
#define XORQ_RDU_RCU XORQ rDu, rCu
#define XORQ_RDA_RCA XORQ rDa, rCa
#define XORQ_RDE_RCE XORQ rDe, rCe
#define mKeccakRound(iState, oState, rc, B_RBI_RCE, G_RT1_RCA, G_RT1_RCE, G_RBA_RCU, K_RT1_RCA, K_RT1_RCE, K_RBA_RCU, M_RT1_RCA, M_RT1_RCE, M_RBE_RCU, S_RDU_RCU, S_RDA_RCA, S_RDE_RCE) \
/* Prepare round */ \
MOVQ rCe, rDa; \
ROLQ $1, rDa; \
\
MOVQ _bi(iState), rCi; \
XORQ _gi(iState), rDi; \
XORQ rCu, rDa; \
XORQ _ki(iState), rCi; \
XORQ _mi(iState), rDi; \
XORQ rDi, rCi; \
\
MOVQ rCi, rDe; \
ROLQ $1, rDe; \
\
MOVQ _bo(iState), rCo; \
XORQ _go(iState), rDo; \
XORQ rCa, rDe; \
XORQ _ko(iState), rCo; \
XORQ _mo(iState), rDo; \
XORQ rDo, rCo; \
\
MOVQ rCo, rDi; \
ROLQ $1, rDi; \
\
MOVQ rCu, rDo; \
XORQ rCe, rDi; \
ROLQ $1, rDo; \
\
MOVQ rCa, rDu; \
XORQ rCi, rDo; \
ROLQ $1, rDu; \
\
/* Result b */ \
MOVQ _ba(iState), rBa; \
MOVQ _ge(iState), rBe; \
XORQ rCo, rDu; \
MOVQ _ki(iState), rBi; \
MOVQ _mo(iState), rBo; \
MOVQ _su(iState), rBu; \
XORQ rDe, rBe; \
ROLQ $44, rBe; \
XORQ rDi, rBi; \
XORQ rDa, rBa; \
ROLQ $43, rBi; \
\
MOVQ rBe, rCa; \
MOVQ rc, rT1; \
ORQ rBi, rCa; \
XORQ rBa, rT1; \
XORQ rT1, rCa; \
MOVQ rCa, _ba(oState); \
\
XORQ rDu, rBu; \
ROLQ $14, rBu; \
MOVQ rBa, rCu; \
ANDQ rBe, rCu; \
XORQ rBu, rCu; \
MOVQ rCu, _bu(oState); \
\
XORQ rDo, rBo; \
ROLQ $21, rBo; \
MOVQ rBo, rT1; \
ANDQ rBu, rT1; \
XORQ rBi, rT1; \
MOVQ rT1, _bi(oState); \
\
NOTQ rBi; \
ORQ rBa, rBu; \
ORQ rBo, rBi; \
XORQ rBo, rBu; \
XORQ rBe, rBi; \
MOVQ rBu, _bo(oState); \
MOVQ rBi, _be(oState); \
B_RBI_RCE; \
\
/* Result g */ \
MOVQ _gu(iState), rBe; \
XORQ rDu, rBe; \
MOVQ _ka(iState), rBi; \
ROLQ $20, rBe; \
XORQ rDa, rBi; \
ROLQ $3, rBi; \
MOVQ _bo(iState), rBa; \
MOVQ rBe, rT1; \
ORQ rBi, rT1; \
XORQ rDo, rBa; \
MOVQ _me(iState), rBo; \
MOVQ _si(iState), rBu; \
ROLQ $28, rBa; \
XORQ rBa, rT1; \
MOVQ rT1, _ga(oState); \
G_RT1_RCA; \
\
XORQ rDe, rBo; \
ROLQ $45, rBo; \
MOVQ rBi, rT1; \
ANDQ rBo, rT1; \
XORQ rBe, rT1; \
MOVQ rT1, _ge(oState); \
G_RT1_RCE; \
\
XORQ rDi, rBu; \
ROLQ $61, rBu; \
MOVQ rBu, rT1; \
ORQ rBa, rT1; \
XORQ rBo, rT1; \
MOVQ rT1, _go(oState); \
\
ANDQ rBe, rBa; \
XORQ rBu, rBa; \
MOVQ rBa, _gu(oState); \
NOTQ rBu; \
G_RBA_RCU; \
\
ORQ rBu, rBo; \
XORQ rBi, rBo; \
MOVQ rBo, _gi(oState); \
\
/* Result k */ \
MOVQ _be(iState), rBa; \
MOVQ _gi(iState), rBe; \
MOVQ _ko(iState), rBi; \
MOVQ _mu(iState), rBo; \
MOVQ _sa(iState), rBu; \
XORQ rDi, rBe; \
ROLQ $6, rBe; \
XORQ rDo, rBi; \
ROLQ $25, rBi; \
MOVQ rBe, rT1; \
ORQ rBi, rT1; \
XORQ rDe, rBa; \
ROLQ $1, rBa; \
XORQ rBa, rT1; \
MOVQ rT1, _ka(oState); \
K_RT1_RCA; \
\
XORQ rDu, rBo; \
ROLQ $8, rBo; \
MOVQ rBi, rT1; \
ANDQ rBo, rT1; \
XORQ rBe, rT1; \
MOVQ rT1, _ke(oState); \
K_RT1_RCE; \
\
XORQ rDa, rBu; \
ROLQ $18, rBu; \
NOTQ rBo; \
MOVQ rBo, rT1; \
ANDQ rBu, rT1; \
XORQ rBi, rT1; \
MOVQ rT1, _ki(oState); \
\
MOVQ rBu, rT1; \
ORQ rBa, rT1; \
XORQ rBo, rT1; \
MOVQ rT1, _ko(oState); \
\
ANDQ rBe, rBa; \
XORQ rBu, rBa; \
MOVQ rBa, _ku(oState); \
K_RBA_RCU; \
\
/* Result m */ \
MOVQ _ga(iState), rBe; \
XORQ rDa, rBe; \
MOVQ _ke(iState), rBi; \
ROLQ $36, rBe; \
XORQ rDe, rBi; \
MOVQ _bu(iState), rBa; \
ROLQ $10, rBi; \
MOVQ rBe, rT1; \
MOVQ _mi(iState), rBo; \
ANDQ rBi, rT1; \
XORQ rDu, rBa; \
MOVQ _so(iState), rBu; \
ROLQ $27, rBa; \
XORQ rBa, rT1; \
MOVQ rT1, _ma(oState); \
M_RT1_RCA; \
\
XORQ rDi, rBo; \
ROLQ $15, rBo; \
MOVQ rBi, rT1; \
ORQ rBo, rT1; \
XORQ rBe, rT1; \
MOVQ rT1, _me(oState); \
M_RT1_RCE; \
\
XORQ rDo, rBu; \
ROLQ $56, rBu; \
NOTQ rBo; \
MOVQ rBo, rT1; \
ORQ rBu, rT1; \
XORQ rBi, rT1; \
MOVQ rT1, _mi(oState); \
\
ORQ rBa, rBe; \
XORQ rBu, rBe; \
MOVQ rBe, _mu(oState); \
\
ANDQ rBa, rBu; \
XORQ rBo, rBu; \
MOVQ rBu, _mo(oState); \
M_RBE_RCU; \
\
/* Result s */ \
MOVQ _bi(iState), rBa; \
MOVQ _go(iState), rBe; \
MOVQ _ku(iState), rBi; \
XORQ rDi, rBa; \
MOVQ _ma(iState), rBo; \
ROLQ $62, rBa; \
XORQ rDo, rBe; \
MOVQ _se(iState), rBu; \
ROLQ $55, rBe; \
\
XORQ rDu, rBi; \
MOVQ rBa, rDu; \
XORQ rDe, rBu; \
ROLQ $2, rBu; \
ANDQ rBe, rDu; \
XORQ rBu, rDu; \
MOVQ rDu, _su(oState); \
\
ROLQ $39, rBi; \
S_RDU_RCU; \
NOTQ rBe; \
XORQ rDa, rBo; \
MOVQ rBe, rDa; \
ANDQ rBi, rDa; \
XORQ rBa, rDa; \
MOVQ rDa, _sa(oState); \
S_RDA_RCA; \
\
ROLQ $41, rBo; \
MOVQ rBi, rDe; \
ORQ rBo, rDe; \
XORQ rBe, rDe; \
MOVQ rDe, _se(oState); \
S_RDE_RCE; \
\
MOVQ rBo, rDi; \
MOVQ rBu, rDo; \
ANDQ rBu, rDi; \
ORQ rBa, rDo; \
XORQ rBi, rDi; \
XORQ rBo, rDo; \
MOVQ rDi, _si(oState); \
MOVQ rDo, _so(oState) \
// func keccakF1600(a *[25]uint64)
TEXT ·keccakF1600(SB), 0, $200-8
MOVQ a+0(FP), rpState
// Convert the user state into an internal state
NOTQ _be(rpState)
NOTQ _bi(rpState)
NOTQ _go(rpState)
NOTQ _ki(rpState)
NOTQ _mi(rpState)
NOTQ _sa(rpState)
// Execute the KeccakF permutation
MOVQ _ba(rpState), rCa
MOVQ _be(rpState), rCe
MOVQ _bu(rpState), rCu
XORQ _ga(rpState), rCa
XORQ _ge(rpState), rCe
XORQ _gu(rpState), rCu
XORQ _ka(rpState), rCa
XORQ _ke(rpState), rCe
XORQ _ku(rpState), rCu
XORQ _ma(rpState), rCa
XORQ _me(rpState), rCe
XORQ _mu(rpState), rCu
XORQ _sa(rpState), rCa
XORQ _se(rpState), rCe
MOVQ _si(rpState), rDi
MOVQ _so(rpState), rDo
XORQ _su(rpState), rCu
mKeccakRound(rpState, rpStack, $0x0000000000000001, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x0000000000008082, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x800000000000808a, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000080008000, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x000000000000808b, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x0000000080000001, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x8000000080008081, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000000008009, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x000000000000008a, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x0000000000000088, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x0000000080008009, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x000000008000000a, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x000000008000808b, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x800000000000008b, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x8000000000008089, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000000008003, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x8000000000008002, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000000000080, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x000000000000800a, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x800000008000000a, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x8000000080008081, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000000008080, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpState, rpStack, $0x0000000080000001, MOVQ_RBI_RCE, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBA_RCU, XORQ_RT1_RCA, XORQ_RT1_RCE, XORQ_RBE_RCU, XORQ_RDU_RCU, XORQ_RDA_RCA, XORQ_RDE_RCE)
mKeccakRound(rpStack, rpState, $0x8000000080008008, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP, NOP)
// Revert the internal state to the user state
NOTQ _be(rpState)
NOTQ _bi(rpState)
NOTQ _go(rpState)
NOTQ _ki(rpState)
NOTQ _mi(rpState)
NOTQ _sa(rpState)
RET

263
vendor/golang.org/x/crypto/sha3/legacy_hash.go generated vendored Normal file
View file

@ -0,0 +1,263 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha3
// This implementation is only used for NewLegacyKeccak256 and
// NewLegacyKeccak512, which are not implemented by crypto/sha3.
// All other functions in this package are wrappers around crypto/sha3.
import (
"crypto/subtle"
"encoding/binary"
"errors"
"hash"
"unsafe"
"golang.org/x/sys/cpu"
)
const (
dsbyteKeccak = 0b00000001
// rateK[c] is the rate in bytes for Keccak[c] where c is the capacity in
// bits. Given the sponge size is 1600 bits, the rate is 1600 - c bits.
rateK256 = (1600 - 256) / 8
rateK512 = (1600 - 512) / 8
rateK1024 = (1600 - 1024) / 8
)
// NewLegacyKeccak256 creates a new Keccak-256 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New256 instead.
func NewLegacyKeccak256() hash.Hash {
return &state{rate: rateK512, outputLen: 32, dsbyte: dsbyteKeccak}
}
// NewLegacyKeccak512 creates a new Keccak-512 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New512 instead.
func NewLegacyKeccak512() hash.Hash {
return &state{rate: rateK1024, outputLen: 64, dsbyte: dsbyteKeccak}
}
// spongeDirection indicates the direction bytes are flowing through the sponge.
type spongeDirection int
const (
// spongeAbsorbing indicates that the sponge is absorbing input.
spongeAbsorbing spongeDirection = iota
// spongeSqueezing indicates that the sponge is being squeezed.
spongeSqueezing
)
type state struct {
a [1600 / 8]byte // main state of the hash
// a[n:rate] is the buffer. If absorbing, it's the remaining space to XOR
// into before running the permutation. If squeezing, it's the remaining
// output to produce before running the permutation.
n, rate int
// dsbyte contains the "domain separation" bits and the first bit of
// the padding. Sections 6.1 and 6.2 of [1] separate the outputs of the
// SHA-3 and SHAKE functions by appending bitstrings to the message.
// Using a little-endian bit-ordering convention, these are "01" for SHA-3
// and "1111" for SHAKE, or 00000010b and 00001111b, respectively. Then the
// padding rule from section 5.1 is applied to pad the message to a multiple
// of the rate, which involves adding a "1" bit, zero or more "0" bits, and
// a final "1" bit. We merge the first "1" bit from the padding into dsbyte,
// giving 00000110b (0x06) and 00011111b (0x1f).
// [1] http://csrc.nist.gov/publications/drafts/fips-202/fips_202_draft.pdf
// "Draft FIPS 202: SHA-3 Standard: Permutation-Based Hash and
// Extendable-Output Functions (May 2014)"
dsbyte byte
outputLen int // the default output size in bytes
state spongeDirection // whether the sponge is absorbing or squeezing
}
// BlockSize returns the rate of sponge underlying this hash function.
func (d *state) BlockSize() int { return d.rate }
// Size returns the output size of the hash function in bytes.
func (d *state) Size() int { return d.outputLen }
// Reset clears the internal state by zeroing the sponge state and
// the buffer indexes, and setting Sponge.state to absorbing.
func (d *state) Reset() {
// Zero the permutation's state.
for i := range d.a {
d.a[i] = 0
}
d.state = spongeAbsorbing
d.n = 0
}
func (d *state) clone() *state {
ret := *d
return &ret
}
// permute applies the KeccakF-1600 permutation.
func (d *state) permute() {
var a *[25]uint64
if cpu.IsBigEndian {
a = new([25]uint64)
for i := range a {
a[i] = binary.LittleEndian.Uint64(d.a[i*8:])
}
} else {
a = (*[25]uint64)(unsafe.Pointer(&d.a))
}
keccakF1600(a)
d.n = 0
if cpu.IsBigEndian {
for i := range a {
binary.LittleEndian.PutUint64(d.a[i*8:], a[i])
}
}
}
// pads appends the domain separation bits in dsbyte, applies
// the multi-bitrate 10..1 padding rule, and permutes the state.
func (d *state) padAndPermute() {
// Pad with this instance's domain-separator bits. We know that there's
// at least one byte of space in the sponge because, if it were full,
// permute would have been called to empty it. dsbyte also contains the
// first one bit for the padding. See the comment in the state struct.
d.a[d.n] ^= d.dsbyte
// This adds the final one bit for the padding. Because of the way that
// bits are numbered from the LSB upwards, the final bit is the MSB of
// the last byte.
d.a[d.rate-1] ^= 0x80
// Apply the permutation
d.permute()
d.state = spongeSqueezing
}
// Write absorbs more data into the hash's state. It panics if any
// output has already been read.
func (d *state) Write(p []byte) (n int, err error) {
if d.state != spongeAbsorbing {
panic("sha3: Write after Read")
}
n = len(p)
for len(p) > 0 {
x := subtle.XORBytes(d.a[d.n:d.rate], d.a[d.n:d.rate], p)
d.n += x
p = p[x:]
// If the sponge is full, apply the permutation.
if d.n == d.rate {
d.permute()
}
}
return
}
// Read squeezes an arbitrary number of bytes from the sponge.
func (d *state) Read(out []byte) (n int, err error) {
// If we're still absorbing, pad and apply the permutation.
if d.state == spongeAbsorbing {
d.padAndPermute()
}
n = len(out)
// Now, do the squeezing.
for len(out) > 0 {
// Apply the permutation if we've squeezed the sponge dry.
if d.n == d.rate {
d.permute()
}
x := copy(out, d.a[d.n:d.rate])
d.n += x
out = out[x:]
}
return
}
// Sum applies padding to the hash state and then squeezes out the desired
// number of output bytes. It panics if any output has already been read.
func (d *state) Sum(in []byte) []byte {
if d.state != spongeAbsorbing {
panic("sha3: Sum after Read")
}
// Make a copy of the original hash so that caller can keep writing
// and summing.
dup := d.clone()
hash := make([]byte, dup.outputLen, 64) // explicit cap to allow stack allocation
dup.Read(hash)
return append(in, hash...)
}
const (
magicKeccak = "sha\x0b"
// magic || rate || main state || n || sponge direction
marshaledSize = len(magicKeccak) + 1 + 200 + 1 + 1
)
func (d *state) MarshalBinary() ([]byte, error) {
return d.AppendBinary(make([]byte, 0, marshaledSize))
}
func (d *state) AppendBinary(b []byte) ([]byte, error) {
switch d.dsbyte {
case dsbyteKeccak:
b = append(b, magicKeccak...)
default:
panic("unknown dsbyte")
}
// rate is at most 168, and n is at most rate.
b = append(b, byte(d.rate))
b = append(b, d.a[:]...)
b = append(b, byte(d.n), byte(d.state))
return b, nil
}
func (d *state) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize {
return errors.New("sha3: invalid hash state")
}
magic := string(b[:len(magicKeccak)])
b = b[len(magicKeccak):]
switch {
case magic == magicKeccak && d.dsbyte == dsbyteKeccak:
default:
return errors.New("sha3: invalid hash state identifier")
}
rate := int(b[0])
b = b[1:]
if rate != d.rate {
return errors.New("sha3: invalid hash state function")
}
copy(d.a[:], b)
b = b[len(d.a):]
n, state := int(b[0]), spongeDirection(b[1])
if n > d.rate {
return errors.New("sha3: invalid hash state")
}
d.n = n
if state != spongeAbsorbing && state != spongeSqueezing {
return errors.New("sha3: invalid hash state")
}
d.state = state
return nil
}

View file

@ -2,10 +2,12 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build !amd64 || purego || !gc
package sha3 package sha3
// This implementation is only used for NewLegacyKeccak256 and
// NewLegacyKeccak512, which are not implemented by crypto/sha3.
// All other functions in this package are wrappers around crypto/sha3.
import "math/bits" import "math/bits"
// rc stores the round constants for use in the ι step. // rc stores the round constants for use in the ι step.

View file

@ -1,18 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.4
package sha3
import (
"crypto"
)
func init() {
crypto.RegisterHash(crypto.SHA3_224, New224)
crypto.RegisterHash(crypto.SHA3_256, New256)
crypto.RegisterHash(crypto.SHA3_384, New384)
crypto.RegisterHash(crypto.SHA3_512, New512)
}

View file

@ -1,197 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha3
// spongeDirection indicates the direction bytes are flowing through the sponge.
type spongeDirection int
const (
// spongeAbsorbing indicates that the sponge is absorbing input.
spongeAbsorbing spongeDirection = iota
// spongeSqueezing indicates that the sponge is being squeezed.
spongeSqueezing
)
const (
// maxRate is the maximum size of the internal buffer. SHAKE-256
// currently needs the largest buffer.
maxRate = 168
)
type state struct {
// Generic sponge components.
a [25]uint64 // main state of the hash
buf []byte // points into storage
rate int // the number of bytes of state to use
// dsbyte contains the "domain separation" bits and the first bit of
// the padding. Sections 6.1 and 6.2 of [1] separate the outputs of the
// SHA-3 and SHAKE functions by appending bitstrings to the message.
// Using a little-endian bit-ordering convention, these are "01" for SHA-3
// and "1111" for SHAKE, or 00000010b and 00001111b, respectively. Then the
// padding rule from section 5.1 is applied to pad the message to a multiple
// of the rate, which involves adding a "1" bit, zero or more "0" bits, and
// a final "1" bit. We merge the first "1" bit from the padding into dsbyte,
// giving 00000110b (0x06) and 00011111b (0x1f).
// [1] http://csrc.nist.gov/publications/drafts/fips-202/fips_202_draft.pdf
// "Draft FIPS 202: SHA-3 Standard: Permutation-Based Hash and
// Extendable-Output Functions (May 2014)"
dsbyte byte
storage storageBuf
// Specific to SHA-3 and SHAKE.
outputLen int // the default output size in bytes
state spongeDirection // whether the sponge is absorbing or squeezing
}
// BlockSize returns the rate of sponge underlying this hash function.
func (d *state) BlockSize() int { return d.rate }
// Size returns the output size of the hash function in bytes.
func (d *state) Size() int { return d.outputLen }
// Reset clears the internal state by zeroing the sponge state and
// the byte buffer, and setting Sponge.state to absorbing.
func (d *state) Reset() {
// Zero the permutation's state.
for i := range d.a {
d.a[i] = 0
}
d.state = spongeAbsorbing
d.buf = d.storage.asBytes()[:0]
}
func (d *state) clone() *state {
ret := *d
if ret.state == spongeAbsorbing {
ret.buf = ret.storage.asBytes()[:len(ret.buf)]
} else {
ret.buf = ret.storage.asBytes()[d.rate-cap(d.buf) : d.rate]
}
return &ret
}
// permute applies the KeccakF-1600 permutation. It handles
// any input-output buffering.
func (d *state) permute() {
switch d.state {
case spongeAbsorbing:
// If we're absorbing, we need to xor the input into the state
// before applying the permutation.
xorIn(d, d.buf)
d.buf = d.storage.asBytes()[:0]
keccakF1600(&d.a)
case spongeSqueezing:
// If we're squeezing, we need to apply the permutation before
// copying more output.
keccakF1600(&d.a)
d.buf = d.storage.asBytes()[:d.rate]
copyOut(d, d.buf)
}
}
// pads appends the domain separation bits in dsbyte, applies
// the multi-bitrate 10..1 padding rule, and permutes the state.
func (d *state) padAndPermute(dsbyte byte) {
if d.buf == nil {
d.buf = d.storage.asBytes()[:0]
}
// Pad with this instance's domain-separator bits. We know that there's
// at least one byte of space in d.buf because, if it were full,
// permute would have been called to empty it. dsbyte also contains the
// first one bit for the padding. See the comment in the state struct.
d.buf = append(d.buf, dsbyte)
zerosStart := len(d.buf)
d.buf = d.storage.asBytes()[:d.rate]
for i := zerosStart; i < d.rate; i++ {
d.buf[i] = 0
}
// This adds the final one bit for the padding. Because of the way that
// bits are numbered from the LSB upwards, the final bit is the MSB of
// the last byte.
d.buf[d.rate-1] ^= 0x80
// Apply the permutation
d.permute()
d.state = spongeSqueezing
d.buf = d.storage.asBytes()[:d.rate]
copyOut(d, d.buf)
}
// Write absorbs more data into the hash's state. It panics if any
// output has already been read.
func (d *state) Write(p []byte) (written int, err error) {
if d.state != spongeAbsorbing {
panic("sha3: Write after Read")
}
if d.buf == nil {
d.buf = d.storage.asBytes()[:0]
}
written = len(p)
for len(p) > 0 {
if len(d.buf) == 0 && len(p) >= d.rate {
// The fast path; absorb a full "rate" bytes of input and apply the permutation.
xorIn(d, p[:d.rate])
p = p[d.rate:]
keccakF1600(&d.a)
} else {
// The slow path; buffer the input until we can fill the sponge, and then xor it in.
todo := d.rate - len(d.buf)
if todo > len(p) {
todo = len(p)
}
d.buf = append(d.buf, p[:todo]...)
p = p[todo:]
// If the sponge is full, apply the permutation.
if len(d.buf) == d.rate {
d.permute()
}
}
}
return
}
// Read squeezes an arbitrary number of bytes from the sponge.
func (d *state) Read(out []byte) (n int, err error) {
// If we're still absorbing, pad and apply the permutation.
if d.state == spongeAbsorbing {
d.padAndPermute(d.dsbyte)
}
n = len(out)
// Now, do the squeezing.
for len(out) > 0 {
n := copy(out, d.buf)
d.buf = d.buf[n:]
out = out[n:]
// Apply the permutation if we've squeezed the sponge dry.
if len(d.buf) == 0 {
d.permute()
}
}
return
}
// Sum applies padding to the hash state and then squeezes out the desired
// number of output bytes. It panics if any output has already been read.
func (d *state) Sum(in []byte) []byte {
if d.state != spongeAbsorbing {
panic("sha3: Sum after Read")
}
// Make a copy of the original hash so that caller can keep writing
// and summing.
dup := d.clone()
hash := make([]byte, dup.outputLen, 64) // explicit cap to allow stack allocation
dup.Read(hash)
return append(in, hash...)
}

View file

@ -1,303 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
package sha3
// This file contains code for using the 'compute intermediate
// message digest' (KIMD) and 'compute last message digest' (KLMD)
// instructions to compute SHA-3 and SHAKE hashes on IBM Z.
import (
"hash"
"golang.org/x/sys/cpu"
)
// codes represent 7-bit KIMD/KLMD function codes as defined in
// the Principles of Operation.
type code uint64
const (
// function codes for KIMD/KLMD
sha3_224 code = 32
sha3_256 = 33
sha3_384 = 34
sha3_512 = 35
shake_128 = 36
shake_256 = 37
nopad = 0x100
)
// kimd is a wrapper for the 'compute intermediate message digest' instruction.
// src must be a multiple of the rate for the given function code.
//
//go:noescape
func kimd(function code, chain *[200]byte, src []byte)
// klmd is a wrapper for the 'compute last message digest' instruction.
// src padding is handled by the instruction.
//
//go:noescape
func klmd(function code, chain *[200]byte, dst, src []byte)
type asmState struct {
a [200]byte // 1600 bit state
buf []byte // care must be taken to ensure cap(buf) is a multiple of rate
rate int // equivalent to block size
storage [3072]byte // underlying storage for buf
outputLen int // output length for full security
function code // KIMD/KLMD function code
state spongeDirection // whether the sponge is absorbing or squeezing
}
func newAsmState(function code) *asmState {
var s asmState
s.function = function
switch function {
case sha3_224:
s.rate = 144
s.outputLen = 28
case sha3_256:
s.rate = 136
s.outputLen = 32
case sha3_384:
s.rate = 104
s.outputLen = 48
case sha3_512:
s.rate = 72
s.outputLen = 64
case shake_128:
s.rate = 168
s.outputLen = 32
case shake_256:
s.rate = 136
s.outputLen = 64
default:
panic("sha3: unrecognized function code")
}
// limit s.buf size to a multiple of s.rate
s.resetBuf()
return &s
}
func (s *asmState) clone() *asmState {
c := *s
c.buf = c.storage[:len(s.buf):cap(s.buf)]
return &c
}
// copyIntoBuf copies b into buf. It will panic if there is not enough space to
// store all of b.
func (s *asmState) copyIntoBuf(b []byte) {
bufLen := len(s.buf)
s.buf = s.buf[:len(s.buf)+len(b)]
copy(s.buf[bufLen:], b)
}
// resetBuf points buf at storage, sets the length to 0 and sets cap to be a
// multiple of the rate.
func (s *asmState) resetBuf() {
max := (cap(s.storage) / s.rate) * s.rate
s.buf = s.storage[:0:max]
}
// Write (via the embedded io.Writer interface) adds more data to the running hash.
// It never returns an error.
func (s *asmState) Write(b []byte) (int, error) {
if s.state != spongeAbsorbing {
panic("sha3: Write after Read")
}
length := len(b)
for len(b) > 0 {
if len(s.buf) == 0 && len(b) >= cap(s.buf) {
// Hash the data directly and push any remaining bytes
// into the buffer.
remainder := len(b) % s.rate
kimd(s.function, &s.a, b[:len(b)-remainder])
if remainder != 0 {
s.copyIntoBuf(b[len(b)-remainder:])
}
return length, nil
}
if len(s.buf) == cap(s.buf) {
// flush the buffer
kimd(s.function, &s.a, s.buf)
s.buf = s.buf[:0]
}
// copy as much as we can into the buffer
n := len(b)
if len(b) > cap(s.buf)-len(s.buf) {
n = cap(s.buf) - len(s.buf)
}
s.copyIntoBuf(b[:n])
b = b[n:]
}
return length, nil
}
// Read squeezes an arbitrary number of bytes from the sponge.
func (s *asmState) Read(out []byte) (n int, err error) {
// The 'compute last message digest' instruction only stores the digest
// at the first operand (dst) for SHAKE functions.
if s.function != shake_128 && s.function != shake_256 {
panic("sha3: can only call Read for SHAKE functions")
}
n = len(out)
// need to pad if we were absorbing
if s.state == spongeAbsorbing {
s.state = spongeSqueezing
// write hash directly into out if possible
if len(out)%s.rate == 0 {
klmd(s.function, &s.a, out, s.buf) // len(out) may be 0
s.buf = s.buf[:0]
return
}
// write hash into buffer
max := cap(s.buf)
if max > len(out) {
max = (len(out)/s.rate)*s.rate + s.rate
}
klmd(s.function, &s.a, s.buf[:max], s.buf)
s.buf = s.buf[:max]
}
for len(out) > 0 {
// flush the buffer
if len(s.buf) != 0 {
c := copy(out, s.buf)
out = out[c:]
s.buf = s.buf[c:]
continue
}
// write hash directly into out if possible
if len(out)%s.rate == 0 {
klmd(s.function|nopad, &s.a, out, nil)
return
}
// write hash into buffer
s.resetBuf()
if cap(s.buf) > len(out) {
s.buf = s.buf[:(len(out)/s.rate)*s.rate+s.rate]
}
klmd(s.function|nopad, &s.a, s.buf, nil)
}
return
}
// Sum appends the current hash to b and returns the resulting slice.
// It does not change the underlying hash state.
func (s *asmState) Sum(b []byte) []byte {
if s.state != spongeAbsorbing {
panic("sha3: Sum after Read")
}
// Copy the state to preserve the original.
a := s.a
// Hash the buffer. Note that we don't clear it because we
// aren't updating the state.
switch s.function {
case sha3_224, sha3_256, sha3_384, sha3_512:
klmd(s.function, &a, nil, s.buf)
return append(b, a[:s.outputLen]...)
case shake_128, shake_256:
d := make([]byte, s.outputLen, 64)
klmd(s.function, &a, d, s.buf)
return append(b, d[:s.outputLen]...)
default:
panic("sha3: unknown function")
}
}
// Reset resets the Hash to its initial state.
func (s *asmState) Reset() {
for i := range s.a {
s.a[i] = 0
}
s.resetBuf()
s.state = spongeAbsorbing
}
// Size returns the number of bytes Sum will return.
func (s *asmState) Size() int {
return s.outputLen
}
// BlockSize returns the hash's underlying block size.
// The Write method must be able to accept any amount
// of data, but it may operate more efficiently if all writes
// are a multiple of the block size.
func (s *asmState) BlockSize() int {
return s.rate
}
// Clone returns a copy of the ShakeHash in its current state.
func (s *asmState) Clone() ShakeHash {
return s.clone()
}
// new224Asm returns an assembly implementation of SHA3-224 if available,
// otherwise it returns nil.
func new224Asm() hash.Hash {
if cpu.S390X.HasSHA3 {
return newAsmState(sha3_224)
}
return nil
}
// new256Asm returns an assembly implementation of SHA3-256 if available,
// otherwise it returns nil.
func new256Asm() hash.Hash {
if cpu.S390X.HasSHA3 {
return newAsmState(sha3_256)
}
return nil
}
// new384Asm returns an assembly implementation of SHA3-384 if available,
// otherwise it returns nil.
func new384Asm() hash.Hash {
if cpu.S390X.HasSHA3 {
return newAsmState(sha3_384)
}
return nil
}
// new512Asm returns an assembly implementation of SHA3-512 if available,
// otherwise it returns nil.
func new512Asm() hash.Hash {
if cpu.S390X.HasSHA3 {
return newAsmState(sha3_512)
}
return nil
}
// newShake128Asm returns an assembly implementation of SHAKE-128 if available,
// otherwise it returns nil.
func newShake128Asm() ShakeHash {
if cpu.S390X.HasSHA3 {
return newAsmState(shake_128)
}
return nil
}
// newShake256Asm returns an assembly implementation of SHAKE-256 if available,
// otherwise it returns nil.
func newShake256Asm() ShakeHash {
if cpu.S390X.HasSHA3 {
return newAsmState(shake_256)
}
return nil
}

View file

@ -1,33 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
#include "textflag.h"
// func kimd(function code, chain *[200]byte, src []byte)
TEXT ·kimd(SB), NOFRAME|NOSPLIT, $0-40
MOVD function+0(FP), R0
MOVD chain+8(FP), R1
LMG src+16(FP), R2, R3 // R2=base, R3=len
continue:
WORD $0xB93E0002 // KIMD --, R2
BVS continue // continue if interrupted
MOVD $0, R0 // reset R0 for pre-go1.8 compilers
RET
// func klmd(function code, chain *[200]byte, dst, src []byte)
TEXT ·klmd(SB), NOFRAME|NOSPLIT, $0-64
// TODO: SHAKE support
MOVD function+0(FP), R0
MOVD chain+8(FP), R1
LMG dst+16(FP), R2, R3 // R2=base, R3=len
LMG src+40(FP), R4, R5 // R4=base, R5=len
continue:
WORD $0xB93F0024 // KLMD R2, R4
BVS continue // continue if interrupted
MOVD $0, R0 // reset R0 for pre-go1.8 compilers
RET

View file

@ -4,19 +4,8 @@
package sha3 package sha3
// This file defines the ShakeHash interface, and provides
// functions for creating SHAKE and cSHAKE instances, as well as utility
// functions for hashing bytes to arbitrary-length output.
//
//
// SHAKE implementation is based on FIPS PUB 202 [1]
// cSHAKE implementations is based on NIST SP 800-185 [2]
//
// [1] https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf
// [2] https://doi.org/10.6028/NIST.SP.800-185
import ( import (
"encoding/binary" "crypto/sha3"
"hash" "hash"
"io" "io"
) )
@ -29,7 +18,7 @@ type ShakeHash interface {
hash.Hash hash.Hash
// Read reads more output from the hash; reading affects the hash's // Read reads more output from the hash; reading affects the hash's
// state. (ShakeHash.Read is thus very different from Hash.Sum) // state. (ShakeHash.Read is thus very different from Hash.Sum.)
// It never returns an error, but subsequent calls to Write or Sum // It never returns an error, but subsequent calls to Write or Sum
// will panic. // will panic.
io.Reader io.Reader
@ -38,97 +27,18 @@ type ShakeHash interface {
Clone() ShakeHash Clone() ShakeHash
} }
// cSHAKE specific context
type cshakeState struct {
*state // SHA-3 state context and Read/Write operations
// initBlock is the cSHAKE specific initialization set of bytes. It is initialized
// by newCShake function and stores concatenation of N followed by S, encoded
// by the method specified in 3.3 of [1].
// It is stored here in order for Reset() to be able to put context into
// initial state.
initBlock []byte
}
// Consts for configuring initial SHA-3 state
const (
dsbyteShake = 0x1f
dsbyteCShake = 0x04
rate128 = 168
rate256 = 136
)
func bytepad(input []byte, w int) []byte {
// leftEncode always returns max 9 bytes
buf := make([]byte, 0, 9+len(input)+w)
buf = append(buf, leftEncode(uint64(w))...)
buf = append(buf, input...)
padlen := w - (len(buf) % w)
return append(buf, make([]byte, padlen)...)
}
func leftEncode(value uint64) []byte {
var b [9]byte
binary.BigEndian.PutUint64(b[1:], value)
// Trim all but last leading zero bytes
i := byte(1)
for i < 8 && b[i] == 0 {
i++
}
// Prepend number of encoded bytes
b[i-1] = 9 - i
return b[i-1:]
}
func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash {
c := cshakeState{state: &state{rate: rate, outputLen: outputLen, dsbyte: dsbyte}}
// leftEncode returns max 9 bytes
c.initBlock = make([]byte, 0, 9*2+len(N)+len(S))
c.initBlock = append(c.initBlock, leftEncode(uint64(len(N)*8))...)
c.initBlock = append(c.initBlock, N...)
c.initBlock = append(c.initBlock, leftEncode(uint64(len(S)*8))...)
c.initBlock = append(c.initBlock, S...)
c.Write(bytepad(c.initBlock, c.rate))
return &c
}
// Reset resets the hash to initial state.
func (c *cshakeState) Reset() {
c.state.Reset()
c.Write(bytepad(c.initBlock, c.rate))
}
// Clone returns copy of a cSHAKE context within its current state.
func (c *cshakeState) Clone() ShakeHash {
b := make([]byte, len(c.initBlock))
copy(b, c.initBlock)
return &cshakeState{state: c.clone(), initBlock: b}
}
// Clone returns copy of SHAKE context within its current state.
func (c *state) Clone() ShakeHash {
return c.clone()
}
// NewShake128 creates a new SHAKE128 variable-output-length ShakeHash. // NewShake128 creates a new SHAKE128 variable-output-length ShakeHash.
// Its generic security strength is 128 bits against all attacks if at // Its generic security strength is 128 bits against all attacks if at
// least 32 bytes of its output are used. // least 32 bytes of its output are used.
func NewShake128() ShakeHash { func NewShake128() ShakeHash {
if h := newShake128Asm(); h != nil { return &shakeWrapper{sha3.NewSHAKE128(), 32, false, sha3.NewSHAKE128}
return h
}
return &state{rate: rate128, outputLen: 32, dsbyte: dsbyteShake}
} }
// NewShake256 creates a new SHAKE256 variable-output-length ShakeHash. // NewShake256 creates a new SHAKE256 variable-output-length ShakeHash.
// Its generic security strength is 256 bits against all attacks if // Its generic security strength is 256 bits against all attacks if
// at least 64 bytes of its output are used. // at least 64 bytes of its output are used.
func NewShake256() ShakeHash { func NewShake256() ShakeHash {
if h := newShake256Asm(); h != nil { return &shakeWrapper{sha3.NewSHAKE256(), 64, false, sha3.NewSHAKE256}
return h
}
return &state{rate: rate256, outputLen: 64, dsbyte: dsbyteShake}
} }
// NewCShake128 creates a new instance of cSHAKE128 variable-output-length ShakeHash, // NewCShake128 creates a new instance of cSHAKE128 variable-output-length ShakeHash,
@ -138,10 +48,9 @@ func NewShake256() ShakeHash {
// computations on same input with different S yield unrelated outputs. // computations on same input with different S yield unrelated outputs.
// When N and S are both empty, this is equivalent to NewShake128. // When N and S are both empty, this is equivalent to NewShake128.
func NewCShake128(N, S []byte) ShakeHash { func NewCShake128(N, S []byte) ShakeHash {
if len(N) == 0 && len(S) == 0 { return &shakeWrapper{sha3.NewCSHAKE128(N, S), 32, false, func() *sha3.SHAKE {
return NewShake128() return sha3.NewCSHAKE128(N, S)
} }}
return newCShake(N, S, rate128, 32, dsbyteCShake)
} }
// NewCShake256 creates a new instance of cSHAKE256 variable-output-length ShakeHash, // NewCShake256 creates a new instance of cSHAKE256 variable-output-length ShakeHash,
@ -151,10 +60,9 @@ func NewCShake128(N, S []byte) ShakeHash {
// computations on same input with different S yield unrelated outputs. // computations on same input with different S yield unrelated outputs.
// When N and S are both empty, this is equivalent to NewShake256. // When N and S are both empty, this is equivalent to NewShake256.
func NewCShake256(N, S []byte) ShakeHash { func NewCShake256(N, S []byte) ShakeHash {
if len(N) == 0 && len(S) == 0 { return &shakeWrapper{sha3.NewCSHAKE256(N, S), 64, false, func() *sha3.SHAKE {
return NewShake256() return sha3.NewCSHAKE256(N, S)
} }}
return newCShake(N, S, rate256, 64, dsbyteCShake)
} }
// ShakeSum128 writes an arbitrary-length digest of data into hash. // ShakeSum128 writes an arbitrary-length digest of data into hash.
@ -170,3 +78,42 @@ func ShakeSum256(hash, data []byte) {
h.Write(data) h.Write(data)
h.Read(hash) h.Read(hash)
} }
// shakeWrapper adds the Size, Sum, and Clone methods to a sha3.SHAKE
// to implement the ShakeHash interface.
type shakeWrapper struct {
*sha3.SHAKE
outputLen int
squeezing bool
newSHAKE func() *sha3.SHAKE
}
func (w *shakeWrapper) Read(p []byte) (n int, err error) {
w.squeezing = true
return w.SHAKE.Read(p)
}
func (w *shakeWrapper) Clone() ShakeHash {
s := w.newSHAKE()
b, err := w.MarshalBinary()
if err != nil {
panic(err) // unreachable
}
if err := s.UnmarshalBinary(b); err != nil {
panic(err) // unreachable
}
return &shakeWrapper{s, w.outputLen, w.squeezing, w.newSHAKE}
}
func (w *shakeWrapper) Size() int { return w.outputLen }
func (w *shakeWrapper) Sum(b []byte) []byte {
if w.squeezing {
panic("sha3: Sum after Read")
}
out := make([]byte, w.outputLen)
// Clone the state so that we don't affect future Write calls.
s := w.Clone()
s.Read(out)
return append(b, out...)
}

View file

@ -1,19 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !gc || purego || !s390x
package sha3
// newShake128Asm returns an assembly implementation of SHAKE-128 if available,
// otherwise it returns nil.
func newShake128Asm() ShakeHash {
return nil
}
// newShake256Asm returns an assembly implementation of SHAKE-256 if available,
// otherwise it returns nil.
func newShake256Asm() ShakeHash {
return nil
}

View file

@ -1,23 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (!amd64 && !386 && !ppc64le) || purego
package sha3
// A storageBuf is an aligned array of maxRate bytes.
type storageBuf [maxRate]byte
func (b *storageBuf) asBytes() *[maxRate]byte {
return (*[maxRate]byte)(b)
}
var (
xorIn = xorInGeneric
copyOut = copyOutGeneric
xorInUnaligned = xorInGeneric
copyOutUnaligned = copyOutGeneric
)
const xorImplementationUnaligned = "generic"

View file

@ -1,28 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha3
import "encoding/binary"
// xorInGeneric xors the bytes in buf into the state; it
// makes no non-portable assumptions about memory layout
// or alignment.
func xorInGeneric(d *state, buf []byte) {
n := len(buf) / 8
for i := 0; i < n; i++ {
a := binary.LittleEndian.Uint64(buf)
d.a[i] ^= a
buf = buf[8:]
}
}
// copyOutGeneric copies uint64s to a byte buffer.
func copyOutGeneric(d *state, b []byte) {
for i := 0; len(b) >= 8; i++ {
binary.LittleEndian.PutUint64(b, d.a[i])
b = b[8:]
}
}

View file

@ -1,66 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (amd64 || 386 || ppc64le) && !purego
package sha3
import "unsafe"
// A storageBuf is an aligned array of maxRate bytes.
type storageBuf [maxRate / 8]uint64
func (b *storageBuf) asBytes() *[maxRate]byte {
return (*[maxRate]byte)(unsafe.Pointer(b))
}
// xorInUnaligned uses unaligned reads and writes to update d.a to contain d.a
// XOR buf.
func xorInUnaligned(d *state, buf []byte) {
n := len(buf)
bw := (*[maxRate / 8]uint64)(unsafe.Pointer(&buf[0]))[: n/8 : n/8]
if n >= 72 {
d.a[0] ^= bw[0]
d.a[1] ^= bw[1]
d.a[2] ^= bw[2]
d.a[3] ^= bw[3]
d.a[4] ^= bw[4]
d.a[5] ^= bw[5]
d.a[6] ^= bw[6]
d.a[7] ^= bw[7]
d.a[8] ^= bw[8]
}
if n >= 104 {
d.a[9] ^= bw[9]
d.a[10] ^= bw[10]
d.a[11] ^= bw[11]
d.a[12] ^= bw[12]
}
if n >= 136 {
d.a[13] ^= bw[13]
d.a[14] ^= bw[14]
d.a[15] ^= bw[15]
d.a[16] ^= bw[16]
}
if n >= 144 {
d.a[17] ^= bw[17]
}
if n >= 168 {
d.a[18] ^= bw[18]
d.a[19] ^= bw[19]
d.a[20] ^= bw[20]
}
}
func copyOutUnaligned(d *state, buf []byte) {
ab := (*[maxRate]uint8)(unsafe.Pointer(&d.a[0]))
copy(buf, ab[:])
}
var (
xorIn = xorInUnaligned
copyOut = copyOutUnaligned
)
const xorImplementationUnaligned = "unaligned"

4
vendor/golang.org/x/mod/LICENSE generated vendored
View file

@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved. Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are modification, are permitted provided that the following conditions are
@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the in the documentation and/or other materials provided with the
distribution. distribution.
* Neither the name of Google Inc. nor the names of its * Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from contributors may be used to endorse or promote products derived from
this software without specific prior written permission. this software without specific prior written permission.

View file

@ -96,10 +96,11 @@ package module
// Changes to the semantics in this file require approval from rsc. // Changes to the semantics in this file require approval from rsc.
import ( import (
"cmp"
"errors" "errors"
"fmt" "fmt"
"path" "path"
"sort" "slices"
"strings" "strings"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
@ -260,7 +261,7 @@ func modPathOK(r rune) bool {
// importPathOK reports whether r can appear in a package import path element. // importPathOK reports whether r can appear in a package import path element.
// //
// Import paths are intermediate between module paths and file paths: we allow // Import paths are intermediate between module paths and file paths: we
// disallow characters that would be confusing or ambiguous as arguments to // disallow characters that would be confusing or ambiguous as arguments to
// 'go get' (such as '@' and ' ' ), but allow certain characters that are // 'go get' (such as '@' and ' ' ), but allow certain characters that are
// otherwise-unambiguous on the command line and historically used for some // otherwise-unambiguous on the command line and historically used for some
@ -657,17 +658,15 @@ func CanonicalVersion(v string) string {
// optionally followed by a tie-breaking suffix introduced by a slash character, // optionally followed by a tie-breaking suffix introduced by a slash character,
// like in "v0.0.1/go.mod". // like in "v0.0.1/go.mod".
func Sort(list []Version) { func Sort(list []Version) {
sort.Slice(list, func(i, j int) bool { slices.SortFunc(list, func(i, j Version) int {
mi := list[i] if i.Path != j.Path {
mj := list[j] return strings.Compare(i.Path, j.Path)
if mi.Path != mj.Path {
return mi.Path < mj.Path
} }
// To help go.sum formatting, allow version/file. // To help go.sum formatting, allow version/file.
// Compare semver prefix by semver rules, // Compare semver prefix by semver rules,
// file by string order. // file by string order.
vi := mi.Version vi := i.Version
vj := mj.Version vj := j.Version
var fi, fj string var fi, fj string
if k := strings.Index(vi, "/"); k >= 0 { if k := strings.Index(vi, "/"); k >= 0 {
vi, fi = vi[:k], vi[k:] vi, fi = vi[:k], vi[k:]
@ -676,9 +675,9 @@ func Sort(list []Version) {
vj, fj = vj[:k], vj[k:] vj, fj = vj[:k], vj[k:]
} }
if vi != vj { if vi != vj {
return semver.Compare(vi, vj) < 0 return semver.Compare(vi, vj)
} }
return fi < fj return cmp.Compare(fi, fj)
}) })
} }
@ -803,8 +802,8 @@ func MatchPrefixPatterns(globs, target string) bool {
for globs != "" { for globs != "" {
// Extract next non-empty glob in comma-separated list. // Extract next non-empty glob in comma-separated list.
var glob string var glob string
if i := strings.Index(globs, ","); i >= 0 { if before, after, ok := strings.Cut(globs, ","); ok {
glob, globs = globs[:i], globs[i+1:] glob, globs = before, after
} else { } else {
glob, globs = globs, "" glob, globs = globs, ""
} }

View file

@ -22,7 +22,10 @@
// as shorthands for vMAJOR.0.0 and vMAJOR.MINOR.0. // as shorthands for vMAJOR.0.0 and vMAJOR.MINOR.0.
package semver package semver
import "sort" import (
"slices"
"strings"
)
// parsed returns the parsed form of a semantic version string. // parsed returns the parsed form of a semantic version string.
type parsed struct { type parsed struct {
@ -42,8 +45,8 @@ func IsValid(v string) bool {
// Canonical returns the canonical formatting of the semantic version v. // Canonical returns the canonical formatting of the semantic version v.
// It fills in any missing .MINOR or .PATCH and discards build metadata. // It fills in any missing .MINOR or .PATCH and discards build metadata.
// Two semantic versions compare equal only if their canonical formattings // Two semantic versions compare equal only if their canonical formatting
// are identical strings. // is an identical string.
// The canonical invalid semantic version is the empty string. // The canonical invalid semantic version is the empty string.
func Canonical(v string) string { func Canonical(v string) string {
p, ok := parse(v) p, ok := parse(v)
@ -154,19 +157,22 @@ func Max(v, w string) string {
// ByVersion implements [sort.Interface] for sorting semantic version strings. // ByVersion implements [sort.Interface] for sorting semantic version strings.
type ByVersion []string type ByVersion []string
func (vs ByVersion) Len() int { return len(vs) } func (vs ByVersion) Len() int { return len(vs) }
func (vs ByVersion) Swap(i, j int) { vs[i], vs[j] = vs[j], vs[i] } func (vs ByVersion) Swap(i, j int) { vs[i], vs[j] = vs[j], vs[i] }
func (vs ByVersion) Less(i, j int) bool { func (vs ByVersion) Less(i, j int) bool { return compareVersion(vs[i], vs[j]) < 0 }
cmp := Compare(vs[i], vs[j])
if cmp != 0 { // Sort sorts a list of semantic version strings using [Compare] and falls back
return cmp < 0 // to use [strings.Compare] if both versions are considered equal.
} func Sort(list []string) {
return vs[i] < vs[j] slices.SortFunc(list, compareVersion)
} }
// Sort sorts a list of semantic version strings using [ByVersion]. func compareVersion(a, b string) int {
func Sort(list []string) { cmp := Compare(a, b)
sort.Sort(ByVersion(list)) if cmp != 0 {
return cmp
}
return strings.Compare(a, b)
} }
func parse(v string) (p parsed, ok bool) { func parse(v string) (p parsed, ok bool) {

4
vendor/golang.org/x/net/LICENSE generated vendored
View file

@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved. Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are modification, are permitted provided that the following conditions are
@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the in the documentation and/or other materials provided with the
distribution. distribution.
* Neither the name of Google Inc. nor the names of its * Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from contributors may be used to endorse or promote products derived from
this software without specific prior written permission. this software without specific prior written permission.

File diff suppressed because it is too large Load diff

View file

@ -78,16 +78,11 @@ example, to process each anchor node in depth-first order:
if err != nil { if err != nil {
// ... // ...
} }
var f func(*html.Node) for n := range doc.Descendants() {
f = func(n *html.Node) {
if n.Type == html.ElementNode && n.Data == "a" { if n.Type == html.ElementNode && n.Data == "a" {
// Do something with n... // Do something with n...
} }
for c := n.FirstChild; c != nil; c = c.NextSibling {
f(c)
}
} }
f(doc)
The relevant specifications include: The relevant specifications include:
https://html.spec.whatwg.org/multipage/syntax.html and https://html.spec.whatwg.org/multipage/syntax.html and

View file

@ -87,7 +87,7 @@ func parseDoctype(s string) (n *Node, quirks bool) {
} }
} }
if lastAttr := n.Attr[len(n.Attr)-1]; lastAttr.Key == "system" && if lastAttr := n.Attr[len(n.Attr)-1]; lastAttr.Key == "system" &&
strings.ToLower(lastAttr.Val) == "http://www.ibm.com/data/dtd/v11/ibmxhtml1-transitional.dtd" { strings.EqualFold(lastAttr.Val, "http://www.ibm.com/data/dtd/v11/ibmxhtml1-transitional.dtd") {
quirks = true quirks = true
} }
} }

View file

@ -299,7 +299,7 @@ func escape(w writer, s string) error {
case '\r': case '\r':
esc = "&#13;" esc = "&#13;"
default: default:
panic("unrecognized escape character") panic("html: unrecognized escape character")
} }
s = s[i+1:] s = s[i+1:]
if _, err := w.WriteString(esc); err != nil { if _, err := w.WriteString(esc); err != nil {

View file

@ -40,8 +40,7 @@ func htmlIntegrationPoint(n *Node) bool {
if n.Data == "annotation-xml" { if n.Data == "annotation-xml" {
for _, a := range n.Attr { for _, a := range n.Attr {
if a.Key == "encoding" { if a.Key == "encoding" {
val := strings.ToLower(a.Val) if strings.EqualFold(a.Val, "text/html") || strings.EqualFold(a.Val, "application/xhtml+xml") {
if val == "text/html" || val == "application/xhtml+xml" {
return true return true
} }
} }

56
vendor/golang.org/x/net/html/iter.go generated vendored Normal file
View file

@ -0,0 +1,56 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.23
package html
import "iter"
// Ancestors returns an iterator over the ancestors of n, starting with n.Parent.
//
// Mutating a Node or its parents while iterating may have unexpected results.
func (n *Node) Ancestors() iter.Seq[*Node] {
_ = n.Parent // eager nil check
return func(yield func(*Node) bool) {
for p := n.Parent; p != nil && yield(p); p = p.Parent {
}
}
}
// ChildNodes returns an iterator over the immediate children of n,
// starting with n.FirstChild.
//
// Mutating a Node or its children while iterating may have unexpected results.
func (n *Node) ChildNodes() iter.Seq[*Node] {
_ = n.FirstChild // eager nil check
return func(yield func(*Node) bool) {
for c := n.FirstChild; c != nil && yield(c); c = c.NextSibling {
}
}
}
// Descendants returns an iterator over all nodes recursively beneath
// n, excluding n itself. Nodes are visited in depth-first preorder.
//
// Mutating a Node or its descendants while iterating may have unexpected results.
func (n *Node) Descendants() iter.Seq[*Node] {
_ = n.FirstChild // eager nil check
return func(yield func(*Node) bool) {
n.descendants(yield)
}
}
func (n *Node) descendants(yield func(*Node) bool) bool {
for c := range n.ChildNodes() {
if !yield(c) || !c.descendants(yield) {
return false
}
}
return true
}

View file

@ -11,6 +11,7 @@ import (
// A NodeType is the type of a Node. // A NodeType is the type of a Node.
type NodeType uint32 type NodeType uint32
//go:generate stringer -type NodeType
const ( const (
ErrorNode NodeType = iota ErrorNode NodeType = iota
TextNode TextNode
@ -38,6 +39,10 @@ var scopeMarker = Node{Type: scopeMarkerNode}
// that it looks like "a<b" rather than "a&lt;b". For element nodes, DataAtom // that it looks like "a<b" rather than "a&lt;b". For element nodes, DataAtom
// is the atom for Data, or zero if Data is not a known tag name. // is the atom for Data, or zero if Data is not a known tag name.
// //
// Node trees may be navigated using the link fields (Parent,
// FirstChild, and so on) or a range loop over iterators such as
// [Node.Descendants].
//
// An empty Namespace implies a "http://www.w3.org/1999/xhtml" namespace. // An empty Namespace implies a "http://www.w3.org/1999/xhtml" namespace.
// Similarly, "math" is short for "http://www.w3.org/1998/Math/MathML", and // Similarly, "math" is short for "http://www.w3.org/1998/Math/MathML", and
// "svg" is short for "http://www.w3.org/2000/svg". // "svg" is short for "http://www.w3.org/2000/svg".

31
vendor/golang.org/x/net/html/nodetype_string.go generated vendored Normal file
View file

@ -0,0 +1,31 @@
// Code generated by "stringer -type NodeType"; DO NOT EDIT.
package html
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ErrorNode-0]
_ = x[TextNode-1]
_ = x[DocumentNode-2]
_ = x[ElementNode-3]
_ = x[CommentNode-4]
_ = x[DoctypeNode-5]
_ = x[RawNode-6]
_ = x[scopeMarkerNode-7]
}
const _NodeType_name = "ErrorNodeTextNodeDocumentNodeElementNodeCommentNodeDoctypeNodeRawNodescopeMarkerNode"
var _NodeType_index = [...]uint8{0, 9, 17, 29, 40, 51, 62, 69, 84}
func (i NodeType) String() string {
idx := int(i) - 0
if i < 0 || idx >= len(_NodeType_index)-1 {
return "NodeType(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _NodeType_name[_NodeType_index[idx]:_NodeType_index[idx+1]]
}

View file

@ -136,7 +136,7 @@ func (p *parser) indexOfElementInScope(s scope, matchTags ...a.Atom) int {
return -1 return -1
} }
default: default:
panic("unreachable") panic(fmt.Sprintf("html: internal error: indexOfElementInScope unknown scope: %d", s))
} }
} }
switch s { switch s {
@ -179,7 +179,7 @@ func (p *parser) clearStackToContext(s scope) {
return return
} }
default: default:
panic("unreachable") panic(fmt.Sprintf("html: internal error: clearStackToContext unknown scope: %d", s))
} }
} }
} }
@ -231,7 +231,14 @@ func (p *parser) addChild(n *Node) {
} }
if n.Type == ElementNode { if n.Type == ElementNode {
p.oe = append(p.oe, n) p.insertOpenElement(n)
}
}
func (p *parser) insertOpenElement(n *Node) {
p.oe = append(p.oe, n)
if len(p.oe) > 512 {
panic("html: open stack of elements exceeds 512 nodes")
} }
} }
@ -810,7 +817,7 @@ func afterHeadIM(p *parser) bool {
p.im = inFramesetIM p.im = inFramesetIM
return true return true
case a.Base, a.Basefont, a.Bgsound, a.Link, a.Meta, a.Noframes, a.Script, a.Style, a.Template, a.Title: case a.Base, a.Basefont, a.Bgsound, a.Link, a.Meta, a.Noframes, a.Script, a.Style, a.Template, a.Title:
p.oe = append(p.oe, p.head) p.insertOpenElement(p.head)
defer p.oe.remove(p.head) defer p.oe.remove(p.head)
return inHeadIM(p) return inHeadIM(p)
case a.Head: case a.Head:
@ -840,6 +847,10 @@ func afterHeadIM(p *parser) bool {
p.parseImpliedToken(StartTagToken, a.Body, a.Body.String()) p.parseImpliedToken(StartTagToken, a.Body, a.Body.String())
p.framesetOK = true p.framesetOK = true
if p.tok.Type == ErrorToken {
// Stop parsing.
return true
}
return false return false
} }
@ -920,7 +931,7 @@ func inBodyIM(p *parser) bool {
p.addElement() p.addElement()
p.im = inFramesetIM p.im = inFramesetIM
return true return true
case a.Address, a.Article, a.Aside, a.Blockquote, a.Center, a.Details, a.Dialog, a.Dir, a.Div, a.Dl, a.Fieldset, a.Figcaption, a.Figure, a.Footer, a.Header, a.Hgroup, a.Main, a.Menu, a.Nav, a.Ol, a.P, a.Section, a.Summary, a.Ul: case a.Address, a.Article, a.Aside, a.Blockquote, a.Center, a.Details, a.Dialog, a.Dir, a.Div, a.Dl, a.Fieldset, a.Figcaption, a.Figure, a.Footer, a.Header, a.Hgroup, a.Main, a.Menu, a.Nav, a.Ol, a.P, a.Search, a.Section, a.Summary, a.Ul:
p.popUntil(buttonScope, a.P) p.popUntil(buttonScope, a.P)
p.addElement() p.addElement()
case a.H1, a.H2, a.H3, a.H4, a.H5, a.H6: case a.H1, a.H2, a.H3, a.H4, a.H5, a.H6:
@ -1031,7 +1042,7 @@ func inBodyIM(p *parser) bool {
if p.tok.DataAtom == a.Input { if p.tok.DataAtom == a.Input {
for _, t := range p.tok.Attr { for _, t := range p.tok.Attr {
if t.Key == "type" { if t.Key == "type" {
if strings.ToLower(t.Val) == "hidden" { if strings.EqualFold(t.Val, "hidden") {
// Skip setting framesetOK = false // Skip setting framesetOK = false
return true return true
} }
@ -1132,7 +1143,7 @@ func inBodyIM(p *parser) bool {
return false return false
} }
return true return true
case a.Address, a.Article, a.Aside, a.Blockquote, a.Button, a.Center, a.Details, a.Dialog, a.Dir, a.Div, a.Dl, a.Fieldset, a.Figcaption, a.Figure, a.Footer, a.Header, a.Hgroup, a.Listing, a.Main, a.Menu, a.Nav, a.Ol, a.Pre, a.Section, a.Summary, a.Ul: case a.Address, a.Article, a.Aside, a.Blockquote, a.Button, a.Center, a.Details, a.Dialog, a.Dir, a.Div, a.Dl, a.Fieldset, a.Figcaption, a.Figure, a.Footer, a.Header, a.Hgroup, a.Listing, a.Main, a.Menu, a.Nav, a.Ol, a.Pre, a.Search, a.Section, a.Summary, a.Ul:
p.popUntil(defaultScope, p.tok.DataAtom) p.popUntil(defaultScope, p.tok.DataAtom)
case a.Form: case a.Form:
if p.oe.contains(a.Template) { if p.oe.contains(a.Template) {
@ -1459,7 +1470,7 @@ func inTableIM(p *parser) bool {
return inHeadIM(p) return inHeadIM(p)
case a.Input: case a.Input:
for _, t := range p.tok.Attr { for _, t := range p.tok.Attr {
if t.Key == "type" && strings.ToLower(t.Val) == "hidden" { if t.Key == "type" && strings.EqualFold(t.Val, "hidden") {
p.addElement() p.addElement()
p.oe.pop() p.oe.pop()
return true return true
@ -1674,7 +1685,7 @@ func inTableBodyIM(p *parser) bool {
return inTableIM(p) return inTableIM(p)
} }
// Section 12.2.6.4.14. // Section 13.2.6.4.14.
func inRowIM(p *parser) bool { func inRowIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case StartTagToken: case StartTagToken:
@ -1686,7 +1697,9 @@ func inRowIM(p *parser) bool {
p.im = inCellIM p.im = inCellIM
return true return true
case a.Caption, a.Col, a.Colgroup, a.Tbody, a.Tfoot, a.Thead, a.Tr: case a.Caption, a.Col, a.Colgroup, a.Tbody, a.Tfoot, a.Thead, a.Tr:
if p.popUntil(tableScope, a.Tr) { if p.elementInScope(tableScope, a.Tr) {
p.clearStackToContext(tableRowScope)
p.oe.pop()
p.im = inTableBodyIM p.im = inTableBodyIM
return false return false
} }
@ -1696,22 +1709,28 @@ func inRowIM(p *parser) bool {
case EndTagToken: case EndTagToken:
switch p.tok.DataAtom { switch p.tok.DataAtom {
case a.Tr: case a.Tr:
if p.popUntil(tableScope, a.Tr) { if p.elementInScope(tableScope, a.Tr) {
p.clearStackToContext(tableRowScope)
p.oe.pop()
p.im = inTableBodyIM p.im = inTableBodyIM
return true return true
} }
// Ignore the token. // Ignore the token.
return true return true
case a.Table: case a.Table:
if p.popUntil(tableScope, a.Tr) { if p.elementInScope(tableScope, a.Tr) {
p.clearStackToContext(tableRowScope)
p.oe.pop()
p.im = inTableBodyIM p.im = inTableBodyIM
return false return false
} }
// Ignore the token. // Ignore the token.
return true return true
case a.Tbody, a.Tfoot, a.Thead: case a.Tbody, a.Tfoot, a.Thead:
if p.elementInScope(tableScope, p.tok.DataAtom) { if p.elementInScope(tableScope, p.tok.DataAtom) && p.elementInScope(tableScope, a.Tr) {
p.parseImpliedToken(EndTagToken, a.Tr, a.Tr.String()) p.clearStackToContext(tableRowScope)
p.oe.pop()
p.im = inTableBodyIM
return false return false
} }
// Ignore the token. // Ignore the token.
@ -2218,16 +2237,20 @@ func parseForeignContent(p *parser) bool {
p.acknowledgeSelfClosingTag() p.acknowledgeSelfClosingTag()
} }
case EndTagToken: case EndTagToken:
if strings.EqualFold(p.oe[len(p.oe)-1].Data, p.tok.Data) {
p.oe = p.oe[:len(p.oe)-1]
return true
}
for i := len(p.oe) - 1; i >= 0; i-- { for i := len(p.oe) - 1; i >= 0; i-- {
if p.oe[i].Namespace == "" {
return p.im(p)
}
if strings.EqualFold(p.oe[i].Data, p.tok.Data) { if strings.EqualFold(p.oe[i].Data, p.tok.Data) {
p.oe = p.oe[:i] p.oe = p.oe[:i]
return true
}
if i > 0 && p.oe[i-1].Namespace == "" {
break break
} }
} }
return true return p.im(p)
default: default:
// Ignore the token. // Ignore the token.
} }
@ -2308,9 +2331,13 @@ func (p *parser) parseCurrentToken() {
} }
} }
func (p *parser) parse() error { func (p *parser) parse() (err error) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = fmt.Errorf("%s", panicErr)
}
}()
// Iterate until EOF. Any other error will cause an early return. // Iterate until EOF. Any other error will cause an early return.
var err error
for err != io.EOF { for err != io.EOF {
// CDATA sections are allowed only in foreign content. // CDATA sections are allowed only in foreign content.
n := p.oe.top() n := p.oe.top()
@ -2339,6 +2366,8 @@ func (p *parser) parse() error {
// <tag>s. Conversely, explicit <tag>s in r's data can be silently dropped, // <tag>s. Conversely, explicit <tag>s in r's data can be silently dropped,
// with no corresponding node in the resulting tree. // with no corresponding node in the resulting tree.
// //
// Parse will reject HTML that is nested deeper than 512 elements.
//
// The input is assumed to be UTF-8 encoded. // The input is assumed to be UTF-8 encoded.
func Parse(r io.Reader) (*Node, error) { func Parse(r io.Reader) (*Node, error) {
return ParseWithOptions(r) return ParseWithOptions(r)

View file

@ -184,7 +184,7 @@ func render1(w writer, n *Node) error {
return err return err
} }
// Add initial newline where there is danger of a newline beging ignored. // Add initial newline where there is danger of a newline being ignored.
if c := n.FirstChild; c != nil && c.Type == TextNode && strings.HasPrefix(c.Data, "\n") { if c := n.FirstChild; c != nil && c.Type == TextNode && strings.HasPrefix(c.Data, "\n") {
switch n.Data { switch n.Data {
case "pre", "listing", "textarea": case "pre", "listing", "textarea":

View file

@ -839,8 +839,22 @@ func (z *Tokenizer) readStartTag() TokenType {
if raw { if raw {
z.rawTag = strings.ToLower(string(z.buf[z.data.start:z.data.end])) z.rawTag = strings.ToLower(string(z.buf[z.data.start:z.data.end]))
} }
// Look for a self-closing token like "<br/>". // Look for a self-closing token (e.g. <br/>).
if z.err == nil && z.buf[z.raw.end-2] == '/' { //
// Originally, we did this by just checking that the last character of the
// tag (ignoring the closing bracket) was a solidus (/) character, but this
// is not always accurate.
//
// We need to be careful that we don't misinterpret a non-self-closing tag
// as self-closing, as can happen if the tag contains unquoted attribute
// values (i.e. <p a=/>).
//
// To avoid this, we check that the last non-bracket character of the tag
// (z.raw.end-2) isn't the same character as the last non-quote character of
// the last attribute of the tag (z.pendingAttr[1].end-1), if the tag has
// attributes.
nAttrs := len(z.attr)
if z.err == nil && z.buf[z.raw.end-2] == '/' && (nAttrs == 0 || z.raw.end-2 != z.attr[nAttrs-1][1].end-1) {
return SelfClosingTagToken return SelfClosingTagToken
} }
return StartTagToken return StartTagToken

View file

@ -8,8 +8,8 @@ package http2
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"net"
"net/http" "net/http"
"sync" "sync"
) )
@ -158,7 +158,7 @@ func (c *dialCall) dial(ctx context.Context, addr string) {
// This code decides which ones live or die. // This code decides which ones live or die.
// The return value used is whether c was used. // The return value used is whether c was used.
// c is never closed. // c is never closed.
func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c *tls.Conn) (used bool, err error) { func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) {
p.mu.Lock() p.mu.Lock()
for _, cc := range p.conns[key] { for _, cc := range p.conns[key] {
if cc.CanTakeNewRequest() { if cc.CanTakeNewRequest() {
@ -194,8 +194,8 @@ type addConnCall struct {
err error err error
} }
func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) { func (c *addConnCall) run(t *Transport, key string, nc net.Conn) {
cc, err := t.NewClientConn(tc) cc, err := t.NewClientConn(nc)
p := c.p p := c.p
p.mu.Lock() p.mu.Lock()

20
vendor/golang.org/x/net/http2/client_priority_go126.go generated vendored Normal file
View file

@ -0,0 +1,20 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.27
package http2
import "net/http"
// Support for go.dev/issue/75500 is added in Go 1.27. In case anyone uses
// x/net with versions before Go 1.27, we return true here so that their write
// scheduler will still be the round-robin write scheduler rather than the RFC
// 9218 write scheduler. That way, older users of Go will not see a sudden
// change of behavior just from importing x/net.
//
// TODO(nsh): remove this file after x/net go.mod is at Go 1.27.
func clientPriorityDisabled(_ *http.Server) bool {
return true
}

13
vendor/golang.org/x/net/http2/client_priority_go127.go generated vendored Normal file
View file

@ -0,0 +1,13 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.27
package http2
import "net/http"
func clientPriorityDisabled(s *http.Server) bool {
return s.DisableClientPriority
}

169
vendor/golang.org/x/net/http2/config.go generated vendored Normal file
View file

@ -0,0 +1,169 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"math"
"net/http"
"time"
)
// http2Config is a package-internal version of net/http.HTTP2Config.
//
// http.HTTP2Config was added in Go 1.24.
// When running with a version of net/http that includes HTTP2Config,
// we merge the configuration with the fields in Transport or Server
// to produce an http2Config.
//
// Zero valued fields in http2Config are interpreted as in the
// net/http.HTTPConfig documentation.
//
// Precedence order for reconciling configurations is:
//
// - Use the net/http.{Server,Transport}.HTTP2Config value, when non-zero.
// - Otherwise use the http2.{Server.Transport} value.
// - If the resulting value is zero or out of range, use a default.
type http2Config struct {
MaxConcurrentStreams uint32
StrictMaxConcurrentRequests bool
MaxDecoderHeaderTableSize uint32
MaxEncoderHeaderTableSize uint32
MaxReadFrameSize uint32
MaxUploadBufferPerConnection int32
MaxUploadBufferPerStream int32
SendPingTimeout time.Duration
PingTimeout time.Duration
WriteByteTimeout time.Duration
PermitProhibitedCipherSuites bool
CountError func(errType string)
}
// configFromServer merges configuration settings from
// net/http.Server.HTTP2Config and http2.Server.
func configFromServer(h1 *http.Server, h2 *Server) http2Config {
conf := http2Config{
MaxConcurrentStreams: h2.MaxConcurrentStreams,
MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
MaxReadFrameSize: h2.MaxReadFrameSize,
MaxUploadBufferPerConnection: h2.MaxUploadBufferPerConnection,
MaxUploadBufferPerStream: h2.MaxUploadBufferPerStream,
SendPingTimeout: h2.ReadIdleTimeout,
PingTimeout: h2.PingTimeout,
WriteByteTimeout: h2.WriteByteTimeout,
PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites,
CountError: h2.CountError,
}
fillNetHTTPConfig(&conf, h1.HTTP2)
setConfigDefaults(&conf, true)
return conf
}
// configFromTransport merges configuration settings from h2 and h2.t1.HTTP2
// (the net/http Transport).
func configFromTransport(h2 *Transport) http2Config {
conf := http2Config{
StrictMaxConcurrentRequests: h2.StrictMaxConcurrentStreams,
MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
MaxReadFrameSize: h2.MaxReadFrameSize,
SendPingTimeout: h2.ReadIdleTimeout,
PingTimeout: h2.PingTimeout,
WriteByteTimeout: h2.WriteByteTimeout,
}
// Unlike most config fields, where out-of-range values revert to the default,
// Transport.MaxReadFrameSize clips.
if conf.MaxReadFrameSize < minMaxFrameSize {
conf.MaxReadFrameSize = minMaxFrameSize
} else if conf.MaxReadFrameSize > maxFrameSize {
conf.MaxReadFrameSize = maxFrameSize
}
if h2.t1 != nil {
fillNetHTTPConfig(&conf, h2.t1.HTTP2)
}
setConfigDefaults(&conf, false)
return conf
}
func setDefault[T ~int | ~int32 | ~uint32 | ~int64](v *T, minval, maxval, defval T) {
if *v < minval || *v > maxval {
*v = defval
}
}
func setConfigDefaults(conf *http2Config, server bool) {
setDefault(&conf.MaxConcurrentStreams, 1, math.MaxUint32, defaultMaxStreams)
setDefault(&conf.MaxEncoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
setDefault(&conf.MaxDecoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
if server {
setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, 1<<20)
} else {
setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, transportDefaultConnFlow)
}
if server {
setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, 1<<20)
} else {
setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, transportDefaultStreamFlow)
}
setDefault(&conf.MaxReadFrameSize, minMaxFrameSize, maxFrameSize, defaultMaxReadFrameSize)
setDefault(&conf.PingTimeout, 1, math.MaxInt64, 15*time.Second)
}
// adjustHTTP1MaxHeaderSize converts a limit in bytes on the size of an HTTP/1 header
// to an HTTP/2 MAX_HEADER_LIST_SIZE value.
func adjustHTTP1MaxHeaderSize(n int64) int64 {
// http2's count is in a slightly different unit and includes 32 bytes per pair.
// So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
const perFieldOverhead = 32 // per http2 spec
const typicalHeaders = 10 // conservative
return n + typicalHeaders*perFieldOverhead
}
func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) {
if h2 == nil {
return
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if http2ConfigStrictMaxConcurrentRequests(h2) {
conf.StrictMaxConcurrentRequests = true
}
if h2.MaxEncoderHeaderTableSize != 0 {
conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize)
}
if h2.MaxDecoderHeaderTableSize != 0 {
conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize)
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if h2.MaxReadFrameSize != 0 {
conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize)
}
if h2.MaxReceiveBufferPerConnection != 0 {
conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection)
}
if h2.MaxReceiveBufferPerStream != 0 {
conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream)
}
if h2.SendPingTimeout != 0 {
conf.SendPingTimeout = h2.SendPingTimeout
}
if h2.PingTimeout != 0 {
conf.PingTimeout = h2.PingTimeout
}
if h2.WriteByteTimeout != 0 {
conf.WriteByteTimeout = h2.WriteByteTimeout
}
if h2.PermitProhibitedCipherSuites {
conf.PermitProhibitedCipherSuites = true
}
if h2.CountError != nil {
conf.CountError = h2.CountError
}
}

15
vendor/golang.org/x/net/http2/config_go125.go generated vendored Normal file
View file

@ -0,0 +1,15 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.26
package http2
import (
"net/http"
)
func http2ConfigStrictMaxConcurrentRequests(h2 *http.HTTP2Config) bool {
return false
}

15
vendor/golang.org/x/net/http2/config_go126.go generated vendored Normal file
View file

@ -0,0 +1,15 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.26
package http2
import (
"net/http"
)
func http2ConfigStrictMaxConcurrentRequests(h2 *http.HTTP2Config) bool {
return h2.StrictMaxConcurrentRequests
}

View file

@ -11,11 +11,13 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"slices"
"strings" "strings"
"sync" "sync"
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"golang.org/x/net/internal/httpsfv"
) )
const frameHeaderLen = 9 const frameHeaderLen = 9
@ -23,40 +25,43 @@ const frameHeaderLen = 9
var padZeros = make([]byte, 255) // zeros for padding var padZeros = make([]byte, 255) // zeros for padding
// A FrameType is a registered frame type as defined in // A FrameType is a registered frame type as defined in
// https://httpwg.org/specs/rfc7540.html#rfc.section.11.2 // https://httpwg.org/specs/rfc7540.html#rfc.section.11.2 and other future
// RFCs.
type FrameType uint8 type FrameType uint8
const ( const (
FrameData FrameType = 0x0 FrameData FrameType = 0x0
FrameHeaders FrameType = 0x1 FrameHeaders FrameType = 0x1
FramePriority FrameType = 0x2 FramePriority FrameType = 0x2
FrameRSTStream FrameType = 0x3 FrameRSTStream FrameType = 0x3
FrameSettings FrameType = 0x4 FrameSettings FrameType = 0x4
FramePushPromise FrameType = 0x5 FramePushPromise FrameType = 0x5
FramePing FrameType = 0x6 FramePing FrameType = 0x6
FrameGoAway FrameType = 0x7 FrameGoAway FrameType = 0x7
FrameWindowUpdate FrameType = 0x8 FrameWindowUpdate FrameType = 0x8
FrameContinuation FrameType = 0x9 FrameContinuation FrameType = 0x9
FramePriorityUpdate FrameType = 0x10
) )
var frameName = map[FrameType]string{ var frameNames = [...]string{
FrameData: "DATA", FrameData: "DATA",
FrameHeaders: "HEADERS", FrameHeaders: "HEADERS",
FramePriority: "PRIORITY", FramePriority: "PRIORITY",
FrameRSTStream: "RST_STREAM", FrameRSTStream: "RST_STREAM",
FrameSettings: "SETTINGS", FrameSettings: "SETTINGS",
FramePushPromise: "PUSH_PROMISE", FramePushPromise: "PUSH_PROMISE",
FramePing: "PING", FramePing: "PING",
FrameGoAway: "GOAWAY", FrameGoAway: "GOAWAY",
FrameWindowUpdate: "WINDOW_UPDATE", FrameWindowUpdate: "WINDOW_UPDATE",
FrameContinuation: "CONTINUATION", FrameContinuation: "CONTINUATION",
FramePriorityUpdate: "PRIORITY_UPDATE",
} }
func (t FrameType) String() string { func (t FrameType) String() string {
if s, ok := frameName[t]; ok { if int(t) < len(frameNames) {
return s return frameNames[t]
} }
return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t)) return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", t)
} }
// Flags is a bitmask of HTTP/2 flags. // Flags is a bitmask of HTTP/2 flags.
@ -124,22 +129,23 @@ var flagName = map[FrameType]map[Flags]string{
// might be 0). // might be 0).
type frameParser func(fc *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) type frameParser func(fc *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error)
var frameParsers = map[FrameType]frameParser{ var frameParsers = [...]frameParser{
FrameData: parseDataFrame, FrameData: parseDataFrame,
FrameHeaders: parseHeadersFrame, FrameHeaders: parseHeadersFrame,
FramePriority: parsePriorityFrame, FramePriority: parsePriorityFrame,
FrameRSTStream: parseRSTStreamFrame, FrameRSTStream: parseRSTStreamFrame,
FrameSettings: parseSettingsFrame, FrameSettings: parseSettingsFrame,
FramePushPromise: parsePushPromise, FramePushPromise: parsePushPromise,
FramePing: parsePingFrame, FramePing: parsePingFrame,
FrameGoAway: parseGoAwayFrame, FrameGoAway: parseGoAwayFrame,
FrameWindowUpdate: parseWindowUpdateFrame, FrameWindowUpdate: parseWindowUpdateFrame,
FrameContinuation: parseContinuationFrame, FrameContinuation: parseContinuationFrame,
FramePriorityUpdate: parsePriorityUpdateFrame,
} }
func typeFrameParser(t FrameType) frameParser { func typeFrameParser(t FrameType) frameParser {
if f := frameParsers[t]; f != nil { if int(t) < len(frameParsers) {
return f return frameParsers[t]
} }
return parseUnknownFrame return parseUnknownFrame
} }
@ -225,6 +231,11 @@ var fhBytes = sync.Pool{
}, },
} }
func invalidHTTP1LookingFrameHeader() FrameHeader {
fh, _ := readFrameHeader(make([]byte, frameHeaderLen), strings.NewReader("HTTP/1.1 "))
return fh
}
// ReadFrameHeader reads 9 bytes from r and returns a FrameHeader. // ReadFrameHeader reads 9 bytes from r and returns a FrameHeader.
// Most users should use Framer.ReadFrame instead. // Most users should use Framer.ReadFrame instead.
func ReadFrameHeader(r io.Reader) (FrameHeader, error) { func ReadFrameHeader(r io.Reader) (FrameHeader, error) {
@ -275,6 +286,8 @@ type Framer struct {
// lastHeaderStream is non-zero if the last frame was an // lastHeaderStream is non-zero if the last frame was an
// unfinished HEADERS/CONTINUATION. // unfinished HEADERS/CONTINUATION.
lastHeaderStream uint32 lastHeaderStream uint32
// lastFrameType holds the type of the last frame for verifying frame order.
lastFrameType FrameType
maxReadSize uint32 maxReadSize uint32
headerBuf [frameHeaderLen]byte headerBuf [frameHeaderLen]byte
@ -342,7 +355,7 @@ func (fr *Framer) maxHeaderListSize() uint32 {
func (f *Framer) startWrite(ftype FrameType, flags Flags, streamID uint32) { func (f *Framer) startWrite(ftype FrameType, flags Flags, streamID uint32) {
// Write the FrameHeader. // Write the FrameHeader.
f.wbuf = append(f.wbuf[:0], f.wbuf = append(f.wbuf[:0],
0, // 3 bytes of length, filled in in endWrite 0, // 3 bytes of length, filled in endWrite
0, 0,
0, 0,
byte(ftype), byte(ftype),
@ -483,30 +496,47 @@ func terminalReadFrameError(err error) bool {
return err != nil return err != nil
} }
// ReadFrame reads a single frame. The returned Frame is only valid // ReadFrameHeader reads the header of the next frame.
// until the next call to ReadFrame. // It reads the 9-byte fixed frame header, and does not read any portion of the
// frame payload. The caller is responsible for consuming the payload, either
// with ReadFrameForHeader or directly from the Framer's io.Reader.
// //
// If the frame is larger than previously set with SetMaxReadFrameSize, the // If the frame is larger than previously set with SetMaxReadFrameSize, it
// returned error is ErrFrameTooLarge. Other errors may be of type // returns the frame header and ErrFrameTooLarge.
// ConnectionError, StreamError, or anything else from the underlying
// reader.
// //
// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID // If the returned FrameHeader.StreamID is non-zero, it indicates the stream
// indicates the stream responsible for the error. // responsible for the error.
func (fr *Framer) ReadFrame() (Frame, error) { func (fr *Framer) ReadFrameHeader() (FrameHeader, error) {
fr.errDetail = nil fr.errDetail = nil
fh, err := readFrameHeader(fr.headerBuf[:], fr.r)
if err != nil {
return fh, err
}
if fh.Length > fr.maxReadSize {
if fh == invalidHTTP1LookingFrameHeader() {
return fh, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", ErrFrameTooLarge)
}
return fh, ErrFrameTooLarge
}
if err := fr.checkFrameOrder(fh); err != nil {
return fh, err
}
return fh, nil
}
// ReadFrameForHeader reads the payload for the frame with the given FrameHeader.
//
// It behaves identically to ReadFrame, other than not checking the maximum
// frame size.
func (fr *Framer) ReadFrameForHeader(fh FrameHeader) (Frame, error) {
if fr.lastFrame != nil { if fr.lastFrame != nil {
fr.lastFrame.invalidate() fr.lastFrame.invalidate()
} }
fh, err := readFrameHeader(fr.headerBuf[:], fr.r)
if err != nil {
return nil, err
}
if fh.Length > fr.maxReadSize {
return nil, ErrFrameTooLarge
}
payload := fr.getReadBuf(fh.Length) payload := fr.getReadBuf(fh.Length)
if _, err := io.ReadFull(fr.r, payload); err != nil { if _, err := io.ReadFull(fr.r, payload); err != nil {
if fh == invalidHTTP1LookingFrameHeader() {
return nil, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", err)
}
return nil, err return nil, err
} }
f, err := typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload) f, err := typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload)
@ -516,9 +546,7 @@ func (fr *Framer) ReadFrame() (Frame, error) {
} }
return nil, err return nil, err
} }
if err := fr.checkFrameOrder(f); err != nil { fr.lastFrame = f
return nil, err
}
if fr.logReads { if fr.logReads {
fr.debugReadLoggerf("http2: Framer %p: read %v", fr, summarizeFrame(f)) fr.debugReadLoggerf("http2: Framer %p: read %v", fr, summarizeFrame(f))
} }
@ -528,6 +556,24 @@ func (fr *Framer) ReadFrame() (Frame, error) {
return f, nil return f, nil
} }
// ReadFrame reads a single frame. The returned Frame is only valid
// until the next call to ReadFrame or ReadFrameBodyForHeader.
//
// If the frame is larger than previously set with SetMaxReadFrameSize, the
// returned error is ErrFrameTooLarge. Other errors may be of type
// ConnectionError, StreamError, or anything else from the underlying
// reader.
//
// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID
// indicates the stream responsible for the error.
func (fr *Framer) ReadFrame() (Frame, error) {
fh, err := fr.ReadFrameHeader()
if err != nil {
return nil, err
}
return fr.ReadFrameForHeader(fh)
}
// connError returns ConnectionError(code) but first // connError returns ConnectionError(code) but first
// stashes away a public reason to the caller can optionally relay it // stashes away a public reason to the caller can optionally relay it
// to the peer before hanging up on them. This might help others debug // to the peer before hanging up on them. This might help others debug
@ -540,20 +586,19 @@ func (fr *Framer) connError(code ErrCode, reason string) error {
// checkFrameOrder reports an error if f is an invalid frame to return // checkFrameOrder reports an error if f is an invalid frame to return
// next from ReadFrame. Mostly it checks whether HEADERS and // next from ReadFrame. Mostly it checks whether HEADERS and
// CONTINUATION frames are contiguous. // CONTINUATION frames are contiguous.
func (fr *Framer) checkFrameOrder(f Frame) error { func (fr *Framer) checkFrameOrder(fh FrameHeader) error {
last := fr.lastFrame lastType := fr.lastFrameType
fr.lastFrame = f fr.lastFrameType = fh.Type
if fr.AllowIllegalReads { if fr.AllowIllegalReads {
return nil return nil
} }
fh := f.Header()
if fr.lastHeaderStream != 0 { if fr.lastHeaderStream != 0 {
if fh.Type != FrameContinuation { if fh.Type != FrameContinuation {
return fr.connError(ErrCodeProtocol, return fr.connError(ErrCodeProtocol,
fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d", fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d",
fh.Type, fh.StreamID, fh.Type, fh.StreamID,
last.Header().Type, fr.lastHeaderStream)) lastType, fr.lastHeaderStream))
} }
if fh.StreamID != fr.lastHeaderStream { if fh.StreamID != fr.lastHeaderStream {
return fr.connError(ErrCodeProtocol, return fr.connError(ErrCodeProtocol,
@ -1141,7 +1186,41 @@ type PriorityFrame struct {
PriorityParam PriorityParam
} }
// PriorityParam are the stream prioritzation parameters. // defaultRFC9218Priority determines what priority we should use as the default
// value.
//
// According to RFC 9218, by default, streams should be given an urgency of 3
// and should be non-incremental. However, making streams non-incremental by
// default would be a huge change to our historical behavior where we would
// round-robin writes across streams. When streams are non-incremental, we
// would process streams of the same urgency one-by-one to completion instead.
//
// To avoid such a sudden change which might break some HTTP/2 users, this
// function allows the caller to specify whether they can actually use the
// default value as specified in RFC 9218. If not, this function will return a
// priority value where streams are incremental by default instead: effectively
// a round-robin between stream of the same urgency.
//
// As an example, a server might not be able to use the RFC 9218 default value
// when it's not sure that the client it is serving is aware of RFC 9218.
func defaultRFC9218Priority(canUseDefault bool) PriorityParam {
if canUseDefault {
return PriorityParam{
urgency: 3,
incremental: 0,
}
}
return PriorityParam{
urgency: 3,
incremental: 1,
}
}
// Note that HTTP/2 has had two different prioritization schemes, and
// PriorityParam struct below is a superset of both schemes. The exported
// symbols are from RFC 7540 and the non-exported ones are from RFC 9218.
// PriorityParam are the stream prioritization parameters.
type PriorityParam struct { type PriorityParam struct {
// StreamDep is a 31-bit stream identifier for the // StreamDep is a 31-bit stream identifier for the
// stream that this stream depends on. Zero means no // stream that this stream depends on. Zero means no
@ -1156,6 +1235,20 @@ type PriorityParam struct {
// the spec, "Add one to the value to obtain a weight between // the spec, "Add one to the value to obtain a weight between
// 1 and 256." // 1 and 256."
Weight uint8 Weight uint8
// "The urgency (u) parameter value is Integer (see Section 3.3.1 of
// [STRUCTURED-FIELDS]), between 0 and 7 inclusive, in descending order of
// priority. The default is 3."
urgency uint8
// "The incremental (i) parameter value is Boolean (see Section 3.3.6 of
// [STRUCTURED-FIELDS]). It indicates if an HTTP response can be processed
// incrementally, i.e., provide some meaningful output as chunks of the
// response arrive."
//
// We use uint8 (i.e. 0 is false, 1 is true) instead of bool so we can
// avoid unnecessary type conversions and because either type takes 1 byte.
incremental uint8
} }
func (p PriorityParam) IsZero() bool { func (p PriorityParam) IsZero() bool {
@ -1204,6 +1297,74 @@ func (f *Framer) WritePriority(streamID uint32, p PriorityParam) error {
return f.endWrite() return f.endWrite()
} }
// PriorityUpdateFrame is a PRIORITY_UPDATE frame as described in
// https://www.rfc-editor.org/rfc/rfc9218.html#name-the-priority_update-frame.
type PriorityUpdateFrame struct {
FrameHeader
Priority string
PrioritizedStreamID uint32
}
func parseRFC9218Priority(s string, canUseDefault bool) (p PriorityParam, ok bool) {
p = defaultRFC9218Priority(canUseDefault)
ok = httpsfv.ParseDictionary(s, func(key, val, _ string) {
switch key {
case "u":
if u, ok := httpsfv.ParseInteger(val); ok && u >= 0 && u <= 7 {
p.urgency = uint8(u)
}
case "i":
if i, ok := httpsfv.ParseBoolean(val); ok {
if i {
p.incremental = 1
} else {
p.incremental = 0
}
}
}
})
if !ok {
return defaultRFC9218Priority(canUseDefault), ok
}
return p, true
}
func parsePriorityUpdateFrame(_ *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) {
if fh.StreamID != 0 {
countError("frame_priority_update_non_zero_stream")
return nil, connError{ErrCodeProtocol, "PRIORITY_UPDATE frame with non-zero stream ID"}
}
if len(payload) < 4 {
countError("frame_priority_update_bad_length")
return nil, connError{ErrCodeFrameSize, fmt.Sprintf("PRIORITY_UPDATE frame payload size was %d; want at least 4", len(payload))}
}
v := binary.BigEndian.Uint32(payload[:4])
streamID := v & 0x7fffffff // mask off high bit
if streamID == 0 {
countError("frame_priority_update_prioritizing_zero_stream")
return nil, connError{ErrCodeProtocol, "PRIORITY_UPDATE frame with prioritized stream ID of zero"}
}
return &PriorityUpdateFrame{
FrameHeader: fh,
PrioritizedStreamID: streamID,
Priority: string(payload[4:]),
}, nil
}
// WritePriorityUpdate writes a PRIORITY_UPDATE frame.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
func (f *Framer) WritePriorityUpdate(streamID uint32, priority string) error {
if !validStreamID(streamID) && !f.AllowIllegalWrites {
return errStreamID
}
f.startWrite(FramePriorityUpdate, 0, 0)
f.writeUint32(streamID)
f.writeBytes([]byte(priority))
return f.endWrite()
}
// A RSTStreamFrame allows for abnormal termination of a stream. // A RSTStreamFrame allows for abnormal termination of a stream.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.4 // See https://httpwg.org/specs/rfc7540.html#rfc.section.6.4
type RSTStreamFrame struct { type RSTStreamFrame struct {
@ -1485,12 +1646,29 @@ func (mh *MetaHeadersFrame) PseudoFields() []hpack.HeaderField {
return mh.Fields return mh.Fields
} }
func (mh *MetaHeadersFrame) rfc9218Priority(priorityAware bool) (p PriorityParam, priorityAwareAfter, hasIntermediary bool) {
var s string
for _, field := range mh.Fields {
if field.Name == "priority" {
s = field.Value
priorityAware = true
}
if slices.Contains([]string{"via", "forwarded", "x-forwarded-for"}, field.Name) {
hasIntermediary = true
}
}
// No need to check for ok. parseRFC9218Priority will return a default
// value if there is no priority field or if the field cannot be parsed.
p, _ = parseRFC9218Priority(s, priorityAware && !hasIntermediary)
return p, priorityAware, hasIntermediary
}
func (mh *MetaHeadersFrame) checkPseudos() error { func (mh *MetaHeadersFrame) checkPseudos() error {
var isRequest, isResponse bool var isRequest, isResponse bool
pf := mh.PseudoFields() pf := mh.PseudoFields()
for i, hf := range pf { for i, hf := range pf {
switch hf.Name { switch hf.Name {
case ":method", ":path", ":scheme", ":authority": case ":method", ":path", ":scheme", ":authority", ":protocol":
isRequest = true isRequest = true
case ":status": case ":status":
isResponse = true isResponse = true
@ -1498,7 +1676,7 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
return pseudoHeaderError(hf.Name) return pseudoHeaderError(hf.Name)
} }
// Check for duplicates. // Check for duplicates.
// This would be a bad algorithm, but N is 4. // This would be a bad algorithm, but N is 5.
// And this doesn't allocate. // And this doesn't allocate.
for _, hf2 := range pf[:i] { for _, hf2 := range pf[:i] {
if hf.Name == hf2.Name { if hf.Name == hf2.Name {

View file

@ -15,21 +15,32 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
) )
var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
// Setting DebugGoroutines to false during a test to disable goroutine debugging
// results in race detector complaints when a test leaves goroutines running before
// returning. Tests shouldn't do this, of course, but when they do it generally shows
// up as infrequent, hard-to-debug flakes. (See #66519.)
//
// Disable goroutine debugging during individual tests with an atomic bool.
// (Note that it's safe to enable/disable debugging mid-test, so the actual race condition
// here is harmless.)
var disableDebugGoroutines atomic.Bool
type goroutineLock uint64 type goroutineLock uint64
func newGoroutineLock() goroutineLock { func newGoroutineLock() goroutineLock {
if !DebugGoroutines { if !DebugGoroutines || disableDebugGoroutines.Load() {
return 0 return 0
} }
return goroutineLock(curGoroutineID()) return goroutineLock(curGoroutineID())
} }
func (g goroutineLock) check() { func (g goroutineLock) check() {
if !DebugGoroutines { if !DebugGoroutines || disableDebugGoroutines.Load() {
return return
} }
if curGoroutineID() != uint64(g) { if curGoroutineID() != uint64(g) {
@ -38,7 +49,7 @@ func (g goroutineLock) check() {
} }
func (g goroutineLock) checkNotOn() { func (g goroutineLock) checkNotOn() {
if !DebugGoroutines { if !DebugGoroutines || disableDebugGoroutines.Load() {
return return
} }
if curGoroutineID() == uint64(g) { if curGoroutineID() == uint64(g) {

View file

@ -132,11 +132,8 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// of the body, and reforward the client preface on the net.Conn this function // of the body, and reforward the client preface on the net.Conn this function
// creates. // creates.
func initH2CWithPriorKnowledge(w http.ResponseWriter) (net.Conn, error) { func initH2CWithPriorKnowledge(w http.ResponseWriter) (net.Conn, error) {
hijacker, ok := w.(http.Hijacker) rc := http.NewResponseController(w)
if !ok { conn, rw, err := rc.Hijack()
return nil, errors.New("h2c: connection does not support Hijack")
}
conn, rw, err := hijacker.Hijack()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -163,10 +160,6 @@ func h2cUpgrade(w http.ResponseWriter, r *http.Request) (_ net.Conn, settings []
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
hijacker, ok := w.(http.Hijacker)
if !ok {
return nil, nil, errors.New("h2c: connection does not support Hijack")
}
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
@ -174,7 +167,8 @@ func h2cUpgrade(w http.ResponseWriter, r *http.Request) (_ net.Conn, settings []
} }
r.Body = io.NopCloser(bytes.NewBuffer(body)) r.Body = io.NopCloser(bytes.NewBuffer(body))
conn, rw, err := hijacker.Hijack() rc := http.NewResponseController(w)
conn, rw, err := rc.Hijack()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -11,21 +11,21 @@
// requires Go 1.6 or later) // requires Go 1.6 or later)
// //
// See https://http2.github.io/ for more information on HTTP/2. // See https://http2.github.io/ for more information on HTTP/2.
//
// See https://http2.golang.org/ for a test server running this code.
package http2 // import "golang.org/x/net/http2" package http2 // import "golang.org/x/net/http2"
import ( import (
"bufio" "bufio"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "net"
"net/http" "net/http"
"os" "os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
) )
@ -34,7 +34,15 @@ var (
VerboseLogs bool VerboseLogs bool
logFrameWrites bool logFrameWrites bool
logFrameReads bool logFrameReads bool
inTests bool
// Enabling extended CONNECT by causes browsers to attempt to use
// WebSockets-over-HTTP/2. This results in problems when the server's websocket
// package doesn't support extended CONNECT.
//
// Disable extended CONNECT by default for now.
//
// Issue #71128.
disableExtendedConnectProtocol = true
) )
func init() { func init() {
@ -47,6 +55,9 @@ func init() {
logFrameWrites = true logFrameWrites = true
logFrameReads = true logFrameReads = true
} }
if strings.Contains(e, "http2xconnect=1") {
disableExtendedConnectProtocol = false
}
} }
const ( const (
@ -138,6 +149,10 @@ func (s Setting) Valid() error {
if s.Val < 16384 || s.Val > 1<<24-1 { if s.Val < 16384 || s.Val > 1<<24-1 {
return ConnectionError(ErrCodeProtocol) return ConnectionError(ErrCodeProtocol)
} }
case SettingEnableConnectProtocol:
if s.Val != 1 && s.Val != 0 {
return ConnectionError(ErrCodeProtocol)
}
} }
return nil return nil
} }
@ -147,21 +162,25 @@ func (s Setting) Valid() error {
type SettingID uint16 type SettingID uint16
const ( const (
SettingHeaderTableSize SettingID = 0x1 SettingHeaderTableSize SettingID = 0x1
SettingEnablePush SettingID = 0x2 SettingEnablePush SettingID = 0x2
SettingMaxConcurrentStreams SettingID = 0x3 SettingMaxConcurrentStreams SettingID = 0x3
SettingInitialWindowSize SettingID = 0x4 SettingInitialWindowSize SettingID = 0x4
SettingMaxFrameSize SettingID = 0x5 SettingMaxFrameSize SettingID = 0x5
SettingMaxHeaderListSize SettingID = 0x6 SettingMaxHeaderListSize SettingID = 0x6
SettingEnableConnectProtocol SettingID = 0x8
SettingNoRFC7540Priorities SettingID = 0x9
) )
var settingName = map[SettingID]string{ var settingName = map[SettingID]string{
SettingHeaderTableSize: "HEADER_TABLE_SIZE", SettingHeaderTableSize: "HEADER_TABLE_SIZE",
SettingEnablePush: "ENABLE_PUSH", SettingEnablePush: "ENABLE_PUSH",
SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
SettingMaxFrameSize: "MAX_FRAME_SIZE", SettingMaxFrameSize: "MAX_FRAME_SIZE",
SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
SettingEnableConnectProtocol: "ENABLE_CONNECT_PROTOCOL",
SettingNoRFC7540Priorities: "NO_RFC7540_PRIORITIES",
} }
func (s SettingID) String() string { func (s SettingID) String() string {
@ -210,12 +229,6 @@ type stringWriter interface {
WriteString(s string) (n int, err error) WriteString(s string) (n int, err error)
} }
// A gate lets two goroutines coordinate their activities.
type gate chan struct{}
func (g gate) Done() { g <- struct{}{} }
func (g gate) Wait() { <-g }
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). // A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
type closeWaiter chan struct{} type closeWaiter chan struct{}
@ -241,13 +254,17 @@ func (cw closeWaiter) Wait() {
// Its buffered writer is lazily allocated as needed, to minimize // Its buffered writer is lazily allocated as needed, to minimize
// idle memory usage with many connections. // idle memory usage with many connections.
type bufferedWriter struct { type bufferedWriter struct {
_ incomparable _ incomparable
w io.Writer // immutable conn net.Conn // immutable
bw *bufio.Writer // non-nil when data is buffered bw *bufio.Writer // non-nil when data is buffered
byteTimeout time.Duration // immutable, WriteByteTimeout
} }
func newBufferedWriter(w io.Writer) *bufferedWriter { func newBufferedWriter(conn net.Conn, timeout time.Duration) *bufferedWriter {
return &bufferedWriter{w: w} return &bufferedWriter{
conn: conn,
byteTimeout: timeout,
}
} }
// bufWriterPoolBufferSize is the size of bufio.Writer's // bufWriterPoolBufferSize is the size of bufio.Writer's
@ -274,7 +291,7 @@ func (w *bufferedWriter) Available() int {
func (w *bufferedWriter) Write(p []byte) (n int, err error) { func (w *bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil { if w.bw == nil {
bw := bufWriterPool.Get().(*bufio.Writer) bw := bufWriterPool.Get().(*bufio.Writer)
bw.Reset(w.w) bw.Reset((*bufferedWriterTimeoutWriter)(w))
w.bw = bw w.bw = bw
} }
return w.bw.Write(p) return w.bw.Write(p)
@ -292,6 +309,32 @@ func (w *bufferedWriter) Flush() error {
return err return err
} }
type bufferedWriterTimeoutWriter bufferedWriter
func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
return writeWithByteTimeout(w.conn, w.byteTimeout, p)
}
// writeWithByteTimeout writes to conn.
// If more than timeout passes without any bytes being written to the connection,
// the write fails.
func writeWithByteTimeout(conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
if timeout <= 0 {
return conn.Write(p)
}
for {
conn.SetWriteDeadline(time.Now().Add(timeout))
nn, err := conn.Write(p[n:])
n += nn
if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
// Either we finished the write, made no progress, or hit the deadline.
// Whichever it is, we're done now.
conn.SetWriteDeadline(time.Time{})
return n, err
}
}
}
func mustUint31(v int32) uint32 { func mustUint31(v int32) uint32 {
if v < 0 || v > 2147483647 { if v < 0 || v > 2147483647 {
panic("out of range") panic("out of range")
@ -362,23 +405,6 @@ func (s *sorter) SortStrings(ss []string) {
s.v = save s.v = save
} }
// validPseudoPath reports whether v is a valid :path pseudo-header
// value. It must be either:
//
// - a non-empty string starting with '/'
// - the string '*', for OPTIONS requests.
//
// For now this is only used a quick check for deciding when to clean
// up Opaque URLs before sending requests from the Transport.
// See golang.org/issue/16847
//
// We used to enforce that the path also didn't start with "//", but
// Google's GFE accepts such paths and Chrome sends them, so ignore
// that part of the spec. See golang.org/issue/19103.
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/') || v == "*"
}
// incomparable is a zero-width, non-comparable type. Adding it to a struct // incomparable is a zero-width, non-comparable type. Adding it to a struct
// makes that struct also non-comparable, and generally doesn't add // makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first). // any size (as long as it's first).

View file

@ -29,6 +29,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"crypto/rand"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -49,13 +50,18 @@ import (
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"golang.org/x/net/internal/httpcommon"
) )
const ( const (
prefaceTimeout = 10 * time.Second prefaceTimeout = 10 * time.Second
firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
handlerChunkWriteSize = 4 << 10 handlerChunkWriteSize = 4 << 10
defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
// maxQueuedControlFrames is the maximum number of control frames like
// SETTINGS, PING and RST_STREAM that will be queued for writing before
// the connection is closed to prevent memory exhaustion attacks.
maxQueuedControlFrames = 10000 maxQueuedControlFrames = 10000
) )
@ -127,6 +133,22 @@ type Server struct {
// If zero or negative, there is no timeout. // If zero or negative, there is no timeout.
IdleTimeout time.Duration IdleTimeout time.Duration
// ReadIdleTimeout is the timeout after which a health check using a ping
// frame will be carried out if no frame is received on the connection.
// If zero, no health check is performed.
ReadIdleTimeout time.Duration
// PingTimeout is the timeout after which the connection will be closed
// if a response to a ping is not received.
// If zero, a default of 15 seconds is used.
PingTimeout time.Duration
// WriteByteTimeout is the timeout after which a connection will be
// closed if no data can be written to it. The timeout begins when data is
// available to write, and is extended whenever any bytes are written.
// If zero or negative, there is no timeout.
WriteByteTimeout time.Duration
// MaxUploadBufferPerConnection is the size of the initial flow // MaxUploadBufferPerConnection is the size of the initial flow
// control window for each connections. The HTTP/2 spec does not // control window for each connections. The HTTP/2 spec does not
// allow this to be smaller than 65535 or larger than 2^32-1. // allow this to be smaller than 65535 or larger than 2^32-1.
@ -156,60 +178,13 @@ type Server struct {
state *serverInternalState state *serverInternalState
} }
func (s *Server) initialConnRecvWindowSize() int32 {
if s.MaxUploadBufferPerConnection >= initialWindowSize {
return s.MaxUploadBufferPerConnection
}
return 1 << 20
}
func (s *Server) initialStreamRecvWindowSize() int32 {
if s.MaxUploadBufferPerStream > 0 {
return s.MaxUploadBufferPerStream
}
return 1 << 20
}
func (s *Server) maxReadFrameSize() uint32 {
if v := s.MaxReadFrameSize; v >= minMaxFrameSize && v <= maxFrameSize {
return v
}
return defaultMaxReadFrameSize
}
func (s *Server) maxConcurrentStreams() uint32 {
if v := s.MaxConcurrentStreams; v > 0 {
return v
}
return defaultMaxStreams
}
func (s *Server) maxDecoderHeaderTableSize() uint32 {
if v := s.MaxDecoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
func (s *Server) maxEncoderHeaderTableSize() uint32 {
if v := s.MaxEncoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
// maxQueuedControlFrames is the maximum number of control frames like
// SETTINGS, PING and RST_STREAM that will be queued for writing before
// the connection is closed to prevent memory exhaustion attacks.
func (s *Server) maxQueuedControlFrames() int {
// TODO: if anybody asks, add a Server field, and remember to define the
// behavior of negative values.
return maxQueuedControlFrames
}
type serverInternalState struct { type serverInternalState struct {
mu sync.Mutex mu sync.Mutex
activeConns map[*serverConn]struct{} activeConns map[*serverConn]struct{}
// Pool of error channels. This is per-Server rather than global
// because channels can't be reused across synctest bubbles.
errChanPool sync.Pool
} }
func (s *serverInternalState) registerConn(sc *serverConn) { func (s *serverInternalState) registerConn(sc *serverConn) {
@ -241,6 +216,27 @@ func (s *serverInternalState) startGracefulShutdown() {
s.mu.Unlock() s.mu.Unlock()
} }
// Global error channel pool used for uninitialized Servers.
// We use a per-Server pool when possible to avoid using channels across synctest bubbles.
var errChanPool = sync.Pool{
New: func() any { return make(chan error, 1) },
}
func (s *serverInternalState) getErrChan() chan error {
if s == nil {
return errChanPool.Get().(chan error) // Server used without calling ConfigureServer
}
return s.errChanPool.Get().(chan error)
}
func (s *serverInternalState) putErrChan(ch chan error) {
if s == nil {
errChanPool.Put(ch) // Server used without calling ConfigureServer
return
}
s.errChanPool.Put(ch)
}
// ConfigureServer adds HTTP/2 support to a net/http Server. // ConfigureServer adds HTTP/2 support to a net/http Server.
// //
// The configuration conf may be nil. // The configuration conf may be nil.
@ -253,7 +249,10 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if conf == nil { if conf == nil {
conf = new(Server) conf = new(Server)
} }
conf.state = &serverInternalState{activeConns: make(map[*serverConn]struct{})} conf.state = &serverInternalState{
activeConns: make(map[*serverConn]struct{}),
errChanPool: sync.Pool{New: func() any { return make(chan error, 1) }},
}
if h1, h2 := s, conf; h2.IdleTimeout == 0 { if h1, h2 := s, conf; h2.IdleTimeout == 0 {
if h1.IdleTimeout != 0 { if h1.IdleTimeout != 0 {
h2.IdleTimeout = h1.IdleTimeout h2.IdleTimeout = h1.IdleTimeout
@ -303,7 +302,7 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if s.TLSNextProto == nil { if s.TLSNextProto == nil {
s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){}
} }
protoHandler := func(hs *http.Server, c *tls.Conn, h http.Handler) { protoHandler := func(hs *http.Server, c net.Conn, h http.Handler, sawClientPreface bool) {
if testHookOnConn != nil { if testHookOnConn != nil {
testHookOnConn() testHookOnConn()
} }
@ -320,12 +319,31 @@ func ConfigureServer(s *http.Server, conf *Server) error {
ctx = bc.BaseContext() ctx = bc.BaseContext()
} }
conf.ServeConn(c, &ServeConnOpts{ conf.ServeConn(c, &ServeConnOpts{
Context: ctx, Context: ctx,
Handler: h, Handler: h,
BaseConfig: hs, BaseConfig: hs,
SawClientPreface: sawClientPreface,
}) })
} }
s.TLSNextProto[NextProtoTLS] = protoHandler s.TLSNextProto[NextProtoTLS] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
protoHandler(hs, c, h, false)
}
// The "unencrypted_http2" TLSNextProto key is used to pass off non-TLS HTTP/2 conns.
//
// A connection passed in this method has already had the HTTP/2 preface read from it.
s.TLSNextProto[nextProtoUnencryptedHTTP2] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
nc, err := unencryptedNetConnFromTLSConn(c)
if err != nil {
if lg := hs.ErrorLog; lg != nil {
lg.Print(err)
} else {
log.Print(err)
}
go c.Close()
return
}
protoHandler(hs, nc, h, true)
}
return nil return nil
} }
@ -400,16 +418,25 @@ func (o *ServeConnOpts) handler() http.Handler {
// //
// The opts parameter is optional. If nil, default values are used. // The opts parameter is optional. If nil, default values are used.
func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
if opts == nil {
opts = &ServeConnOpts{}
}
s.serveConn(c, opts, nil)
}
func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) {
baseCtx, cancel := serverConnBaseContext(c, opts) baseCtx, cancel := serverConnBaseContext(c, opts)
defer cancel() defer cancel()
http1srv := opts.baseConfig()
conf := configFromServer(http1srv, s)
sc := &serverConn{ sc := &serverConn{
srv: s, srv: s,
hs: opts.baseConfig(), hs: http1srv,
conn: c, conn: c,
baseCtx: baseCtx, baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(), remoteAddrStr: c.RemoteAddr().String(),
bw: newBufferedWriter(c), bw: newBufferedWriter(c, conf.WriteByteTimeout),
handler: opts.handler(), handler: opts.handler(),
streams: make(map[uint32]*stream), streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult), readFrameCh: make(chan readFrameResult),
@ -419,13 +446,19 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way
doneServing: make(chan struct{}), doneServing: make(chan struct{}),
clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value" clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value"
advMaxStreams: s.maxConcurrentStreams(), advMaxStreams: conf.MaxConcurrentStreams,
initialStreamSendWindowSize: initialWindowSize, initialStreamSendWindowSize: initialWindowSize,
initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
maxFrameSize: initialMaxFrameSize, maxFrameSize: initialMaxFrameSize,
pingTimeout: conf.PingTimeout,
countErrorFunc: conf.CountError,
serveG: newGoroutineLock(), serveG: newGoroutineLock(),
pushEnabled: true, pushEnabled: true,
sawClientPreface: opts.SawClientPreface, sawClientPreface: opts.SawClientPreface,
} }
if newf != nil {
newf(sc)
}
s.state.registerConn(sc) s.state.registerConn(sc)
defer s.state.unregisterConn(sc) defer s.state.unregisterConn(sc)
@ -439,10 +472,13 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
sc.conn.SetWriteDeadline(time.Time{}) sc.conn.SetWriteDeadline(time.Time{})
} }
if s.NewWriteScheduler != nil { switch {
case s.NewWriteScheduler != nil:
sc.writeSched = s.NewWriteScheduler() sc.writeSched = s.NewWriteScheduler()
} else { case clientPriorityDisabled(http1srv):
sc.writeSched = newRoundRobinWriteScheduler() sc.writeSched = newRoundRobinWriteScheduler()
default:
sc.writeSched = newPriorityWriteSchedulerRFC9218()
} }
// These start at the RFC-specified defaults. If there is a higher // These start at the RFC-specified defaults. If there is a higher
@ -451,15 +487,15 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
sc.flow.add(initialWindowSize) sc.flow.add(initialWindowSize)
sc.inflow.init(initialWindowSize) sc.inflow.init(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize()) sc.hpackEncoder.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize)
fr := NewFramer(sc.bw, c) fr := NewFramer(sc.bw, c)
if s.CountError != nil { if conf.CountError != nil {
fr.countError = s.CountError fr.countError = conf.CountError
} }
fr.ReadMetaHeaders = hpack.NewDecoder(s.maxDecoderHeaderTableSize(), nil) fr.ReadMetaHeaders = hpack.NewDecoder(conf.MaxDecoderHeaderTableSize, nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.MaxHeaderListSize = sc.maxHeaderListSize()
fr.SetMaxReadFrameSize(s.maxReadFrameSize()) fr.SetMaxReadFrameSize(conf.MaxReadFrameSize)
sc.framer = fr sc.framer = fr
if tc, ok := c.(connectionStater); ok { if tc, ok := c.(connectionStater); ok {
@ -492,7 +528,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// So for now, do nothing here again. // So for now, do nothing here again.
} }
if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) { if !conf.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
// "Endpoints MAY choose to generate a connection error // "Endpoints MAY choose to generate a connection error
// (Section 5.4.1) of type INADEQUATE_SECURITY if one of // (Section 5.4.1) of type INADEQUATE_SECURITY if one of
// the prohibited cipher suites are negotiated." // the prohibited cipher suites are negotiated."
@ -529,7 +565,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
opts.UpgradeRequest = nil opts.UpgradeRequest = nil
} }
sc.serve() sc.serve(conf)
} }
func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) { func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) {
@ -569,6 +605,7 @@ type serverConn struct {
tlsState *tls.ConnectionState // shared by all handlers, like net/http tlsState *tls.ConnectionState // shared by all handlers, like net/http
remoteAddrStr string remoteAddrStr string
writeSched WriteScheduler writeSched WriteScheduler
countErrorFunc func(errType string)
// Everything following is owned by the serve loop; use serveG.check(): // Everything following is owned by the serve loop; use serveG.check():
serveG goroutineLock // used to verify funcs are on serve() serveG goroutineLock // used to verify funcs are on serve()
@ -588,6 +625,7 @@ type serverConn struct {
streams map[uint32]*stream streams map[uint32]*stream
unstartedHandlers []unstartedHandler unstartedHandlers []unstartedHandler
initialStreamSendWindowSize int32 initialStreamSendWindowSize int32
initialStreamRecvWindowSize int32
maxFrameSize int32 maxFrameSize int32
peerMaxHeaderListSize uint32 // zero means unknown (default) peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
@ -598,9 +636,14 @@ type serverConn struct {
inGoAway bool // we've started to or sent GOAWAY inGoAway bool // we've started to or sent GOAWAY
inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
needToSendGoAway bool // we need to schedule a GOAWAY frame write needToSendGoAway bool // we need to schedule a GOAWAY frame write
pingSent bool
sentPingData [8]byte
goAwayCode ErrCode goAwayCode ErrCode
shutdownTimer *time.Timer // nil until used shutdownTimer *time.Timer // nil until used
idleTimer *time.Timer // nil if unused idleTimer *time.Timer // nil if unused
readIdleTimeout time.Duration
pingTimeout time.Duration
readIdleTimer *time.Timer // nil if unused
// Owned by the writeFrameAsync goroutine: // Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer headerWriteBuf bytes.Buffer
@ -608,6 +651,23 @@ type serverConn struct {
// Used by startGracefulShutdown. // Used by startGracefulShutdown.
shutdownOnce sync.Once shutdownOnce sync.Once
// Used for RFC 9218 prioritization.
hasIntermediary bool // connection is done via an intermediary / proxy
priorityAware bool // the client has sent priority signal, meaning that it is aware of it.
}
func (sc *serverConn) writeSchedIgnoresRFC7540() bool {
switch sc.writeSched.(type) {
case *priorityWriteSchedulerRFC9218:
return true
case *randomWriteScheduler:
return true
case *roundRobinWriteScheduler:
return true
default:
return false
}
} }
func (sc *serverConn) maxHeaderListSize() uint32 { func (sc *serverConn) maxHeaderListSize() uint32 {
@ -615,11 +675,7 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
if n <= 0 { if n <= 0 {
n = http.DefaultMaxHeaderBytes n = http.DefaultMaxHeaderBytes
} }
// http2's count is in a slightly different unit and includes 32 bytes per pair. return uint32(adjustHTTP1MaxHeaderSize(int64(n)))
// So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
const perFieldOverhead = 32 // per http2 spec
const typicalHeaders = 10 // conservative
return uint32(n + typicalHeaders*perFieldOverhead)
} }
func (sc *serverConn) curOpenStreams() uint32 { func (sc *serverConn) curOpenStreams() uint32 {
@ -775,8 +831,7 @@ const maxCachedCanonicalHeadersKeysSize = 2048
func (sc *serverConn) canonicalHeader(v string) string { func (sc *serverConn) canonicalHeader(v string) string {
sc.serveG.check() sc.serveG.check()
buildCommonHeaderMapsOnce() cv, ok := httpcommon.CachedCanonicalHeader(v)
cv, ok := commonCanonHeader[v]
if ok { if ok {
return cv return cv
} }
@ -811,8 +866,8 @@ type readFrameResult struct {
// consumer is done with the frame. // consumer is done with the frame.
// It's run on its own goroutine. // It's run on its own goroutine.
func (sc *serverConn) readFrames() { func (sc *serverConn) readFrames() {
gate := make(gate) gate := make(chan struct{})
gateDone := gate.Done gateDone := func() { gate <- struct{}{} }
for { for {
f, err := sc.framer.ReadFrame() f, err := sc.framer.ReadFrame()
select { select {
@ -881,7 +936,7 @@ func (sc *serverConn) notePanic() {
} }
} }
func (sc *serverConn) serve() { func (sc *serverConn) serve(conf http2Config) {
sc.serveG.check() sc.serveG.check()
defer sc.notePanic() defer sc.notePanic()
defer sc.conn.Close() defer sc.conn.Close()
@ -893,20 +948,27 @@ func (sc *serverConn) serve() {
sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
} }
settings := writeSettings{
{SettingMaxFrameSize, conf.MaxReadFrameSize},
{SettingMaxConcurrentStreams, sc.advMaxStreams},
{SettingMaxHeaderListSize, sc.maxHeaderListSize()},
{SettingHeaderTableSize, conf.MaxDecoderHeaderTableSize},
{SettingInitialWindowSize, uint32(sc.initialStreamRecvWindowSize)},
}
if !disableExtendedConnectProtocol {
settings = append(settings, Setting{SettingEnableConnectProtocol, 1})
}
if sc.writeSchedIgnoresRFC7540() {
settings = append(settings, Setting{SettingNoRFC7540Priorities, 1})
}
sc.writeFrame(FrameWriteRequest{ sc.writeFrame(FrameWriteRequest{
write: writeSettings{ write: settings,
{SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
{SettingMaxConcurrentStreams, sc.advMaxStreams},
{SettingMaxHeaderListSize, sc.maxHeaderListSize()},
{SettingHeaderTableSize, sc.srv.maxDecoderHeaderTableSize()},
{SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())},
},
}) })
sc.unackedSettings++ sc.unackedSettings++
// Each connection starts with initialWindowSize inflow tokens. // Each connection starts with initialWindowSize inflow tokens.
// If a higher value is configured, we add more tokens. // If a higher value is configured, we add more tokens.
if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 { if diff := conf.MaxUploadBufferPerConnection - initialWindowSize; diff > 0 {
sc.sendWindowUpdate(nil, int(diff)) sc.sendWindowUpdate(nil, int(diff))
} }
@ -926,11 +988,18 @@ func (sc *serverConn) serve() {
defer sc.idleTimer.Stop() defer sc.idleTimer.Stop()
} }
if conf.SendPingTimeout > 0 {
sc.readIdleTimeout = conf.SendPingTimeout
sc.readIdleTimer = time.AfterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
defer sc.readIdleTimer.Stop()
}
go sc.readFrames() // closed by defer sc.conn.Close above go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer) settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer)
defer settingsTimer.Stop() defer settingsTimer.Stop()
lastFrameTime := time.Now()
loopNum := 0 loopNum := 0
for { for {
loopNum++ loopNum++
@ -944,6 +1013,7 @@ func (sc *serverConn) serve() {
case res := <-sc.wroteFrameCh: case res := <-sc.wroteFrameCh:
sc.wroteFrame(res) sc.wroteFrame(res)
case res := <-sc.readFrameCh: case res := <-sc.readFrameCh:
lastFrameTime = time.Now()
// Process any written frames before reading new frames from the client since a // Process any written frames before reading new frames from the client since a
// written frame could have triggered a new stream to be started. // written frame could have triggered a new stream to be started.
if sc.writingFrameAsync { if sc.writingFrameAsync {
@ -975,6 +1045,8 @@ func (sc *serverConn) serve() {
case idleTimerMsg: case idleTimerMsg:
sc.vlogf("connection is idle") sc.vlogf("connection is idle")
sc.goAway(ErrCodeNo) sc.goAway(ErrCodeNo)
case readIdleTimerMsg:
sc.handlePingTimer(lastFrameTime)
case shutdownTimerMsg: case shutdownTimerMsg:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return return
@ -997,7 +1069,7 @@ func (sc *serverConn) serve() {
// If the peer is causing us to generate a lot of control frames, // If the peer is causing us to generate a lot of control frames,
// but not reading them from us, assume they are trying to make us // but not reading them from us, assume they are trying to make us
// run out of memory. // run out of memory.
if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() { if sc.queuedControlFrames > maxQueuedControlFrames {
sc.vlogf("http2: too many control frames in send queue, closing connection") sc.vlogf("http2: too many control frames in send queue, closing connection")
return return
} }
@ -1013,12 +1085,42 @@ func (sc *serverConn) serve() {
} }
} }
func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) {
if sc.pingSent {
sc.logf("timeout waiting for PING response")
if f := sc.countErrorFunc; f != nil {
f("conn_close_lost_ping")
}
sc.conn.Close()
return
}
pingAt := lastFrameReadTime.Add(sc.readIdleTimeout)
now := time.Now()
if pingAt.After(now) {
// We received frames since arming the ping timer.
// Reset it for the next possible timeout.
sc.readIdleTimer.Reset(pingAt.Sub(now))
return
}
sc.pingSent = true
// Ignore crypto/rand.Read errors: It generally can't fail, and worse case if it does
// is we send a PING frame containing 0s.
_, _ = rand.Read(sc.sentPingData[:])
sc.writeFrame(FrameWriteRequest{
write: &writePing{data: sc.sentPingData},
})
sc.readIdleTimer.Reset(sc.pingTimeout)
}
type serverMessage int type serverMessage int
// Message values sent to serveMsgCh. // Message values sent to serveMsgCh.
var ( var (
settingsTimerMsg = new(serverMessage) settingsTimerMsg = new(serverMessage)
idleTimerMsg = new(serverMessage) idleTimerMsg = new(serverMessage)
readIdleTimerMsg = new(serverMessage)
shutdownTimerMsg = new(serverMessage) shutdownTimerMsg = new(serverMessage)
gracefulShutdownMsg = new(serverMessage) gracefulShutdownMsg = new(serverMessage)
handlerDoneMsg = new(serverMessage) handlerDoneMsg = new(serverMessage)
@ -1026,6 +1128,7 @@ var (
func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) } func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) } func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) }
func (sc *serverConn) onReadIdleTimer() { sc.sendServeMsg(readIdleTimerMsg) }
func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) } func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) }
func (sc *serverConn) sendServeMsg(msg interface{}) { func (sc *serverConn) sendServeMsg(msg interface{}) {
@ -1072,10 +1175,6 @@ func (sc *serverConn) readPreface() error {
} }
} }
var errChanPool = sync.Pool{
New: func() interface{} { return make(chan error, 1) },
}
var writeDataPool = sync.Pool{ var writeDataPool = sync.Pool{
New: func() interface{} { return new(writeData) }, New: func() interface{} { return new(writeData) },
} }
@ -1083,7 +1182,7 @@ var writeDataPool = sync.Pool{
// writeDataFromHandler writes DATA response frames from a handler on // writeDataFromHandler writes DATA response frames from a handler on
// the given stream. // the given stream.
func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error { func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error {
ch := errChanPool.Get().(chan error) ch := sc.srv.state.getErrChan()
writeArg := writeDataPool.Get().(*writeData) writeArg := writeDataPool.Get().(*writeData)
*writeArg = writeData{stream.id, data, endStream} *writeArg = writeData{stream.id, data, endStream}
err := sc.writeFrameFromHandler(FrameWriteRequest{ err := sc.writeFrameFromHandler(FrameWriteRequest{
@ -1115,7 +1214,7 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStrea
return errStreamClosed return errStreamClosed
} }
} }
errChanPool.Put(ch) sc.srv.state.putErrChan(ch)
if frameWriteDone { if frameWriteDone {
writeDataPool.Put(writeArg) writeDataPool.Put(writeArg)
} }
@ -1278,6 +1377,10 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
sc.writingFrame = false sc.writingFrame = false
sc.writingFrameAsync = false sc.writingFrameAsync = false
if res.err != nil {
sc.conn.Close()
}
wr := res.wr wr := res.wr
if writeEndsStream(wr.write) { if writeEndsStream(wr.write) {
@ -1543,6 +1646,8 @@ func (sc *serverConn) processFrame(f Frame) error {
// A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE
// frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
return sc.countError("push_promise", ConnectionError(ErrCodeProtocol)) return sc.countError("push_promise", ConnectionError(ErrCodeProtocol))
case *PriorityUpdateFrame:
return sc.processPriorityUpdate(f)
default: default:
sc.vlogf("http2: server ignoring frame: %v", f.Header()) sc.vlogf("http2: server ignoring frame: %v", f.Header())
return nil return nil
@ -1552,6 +1657,11 @@ func (sc *serverConn) processFrame(f Frame) error {
func (sc *serverConn) processPing(f *PingFrame) error { func (sc *serverConn) processPing(f *PingFrame) error {
sc.serveG.check() sc.serveG.check()
if f.IsAck() { if f.IsAck() {
if sc.pingSent && sc.sentPingData == f.Data {
// This is a response to a PING we sent.
sc.pingSent = false
sc.readIdleTimer.Reset(sc.readIdleTimeout)
}
// 6.7 PING: " An endpoint MUST NOT respond to PING frames // 6.7 PING: " An endpoint MUST NOT respond to PING frames
// containing this flag." // containing this flag."
return nil return nil
@ -1639,7 +1749,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
delete(sc.streams, st.id) delete(sc.streams, st.id)
if len(sc.streams) == 0 { if len(sc.streams) == 0 {
sc.setConnState(http.StateIdle) sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout > 0 { if sc.srv.IdleTimeout > 0 && sc.idleTimer != nil {
sc.idleTimer.Reset(sc.srv.IdleTimeout) sc.idleTimer.Reset(sc.srv.IdleTimeout)
} }
if h1ServerKeepAlivesDisabled(sc.hs) { if h1ServerKeepAlivesDisabled(sc.hs) {
@ -1661,6 +1771,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
} }
} }
st.closeErr = err st.closeErr = err
st.cancelCtx()
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.CloseStream(st.id) sc.writeSched.CloseStream(st.id)
} }
@ -1714,6 +1825,13 @@ func (sc *serverConn) processSetting(s Setting) error {
sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31 sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31
case SettingMaxHeaderListSize: case SettingMaxHeaderListSize:
sc.peerMaxHeaderListSize = s.Val sc.peerMaxHeaderListSize = s.Val
case SettingEnableConnectProtocol:
// Receipt of this parameter by a server does not
// have any impact
case SettingNoRFC7540Priorities:
if s.Val > 1 {
return ConnectionError(ErrCodeProtocol)
}
default: default:
// Unknown setting: "An endpoint that receives a SETTINGS // Unknown setting: "An endpoint that receives a SETTINGS
// frame with any unknown or unsupported identifier MUST // frame with any unknown or unsupported identifier MUST
@ -1984,13 +2102,33 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
if f.StreamEnded() { if f.StreamEnded() {
initialState = stateHalfClosedRemote initialState = stateHalfClosedRemote
} }
st := sc.newStream(id, 0, initialState)
// We are handling two special cases here:
// 1. When a request is sent via an intermediary, we force priority to be
// u=3,i. This is essentially a round-robin behavior, and is done to ensure
// fairness between, for example, multiple clients using the same proxy.
// 2. Until a client has shown that it is aware of RFC 9218, we make its
// streams non-incremental by default. This is done to preserve the
// historical behavior of handling streams in a round-robin manner, rather
// than one-by-one to completion.
initialPriority := defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary)
if _, ok := sc.writeSched.(*priorityWriteSchedulerRFC9218); ok && !sc.hasIntermediary {
headerPriority, priorityAware, hasIntermediary := f.rfc9218Priority(sc.priorityAware)
initialPriority = headerPriority
sc.hasIntermediary = hasIntermediary
if priorityAware {
sc.priorityAware = true
}
}
st := sc.newStream(id, 0, initialState, initialPriority)
if f.HasPriority() { if f.HasPriority() {
if err := sc.checkPriority(f.StreamID, f.Priority); err != nil { if err := sc.checkPriority(f.StreamID, f.Priority); err != nil {
return err return err
} }
sc.writeSched.AdjustStream(st.id, f.Priority) if !sc.writeSchedIgnoresRFC7540() {
sc.writeSched.AdjustStream(st.id, f.Priority)
}
} }
rw, req, err := sc.newWriterAndRequest(st, f) rw, req, err := sc.newWriterAndRequest(st, f)
@ -2031,7 +2169,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) {
sc.serveG.check() sc.serveG.check()
id := uint32(1) id := uint32(1)
sc.maxClientStreamID = id sc.maxClientStreamID = id
st := sc.newStream(id, 0, stateHalfClosedRemote) st := sc.newStream(id, 0, stateHalfClosedRemote, defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary))
st.reqTrailer = req.Trailer st.reqTrailer = req.Trailer
if st.reqTrailer != nil { if st.reqTrailer != nil {
st.trailer = make(http.Header) st.trailer = make(http.Header)
@ -2096,11 +2234,32 @@ func (sc *serverConn) processPriority(f *PriorityFrame) error {
if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil { if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil {
return err return err
} }
// We need to avoid calling AdjustStream when using the RFC 9218 write
// scheduler. Otherwise, incremental's zero value in PriorityParam will
// unexpectedly make all streams non-incremental. This causes us to process
// streams one-by-one to completion rather than doing it in a round-robin
// manner (the historical behavior), which might be unexpected to users.
if sc.writeSchedIgnoresRFC7540() {
return nil
}
sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam) sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam)
return nil return nil
} }
func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream { func (sc *serverConn) processPriorityUpdate(f *PriorityUpdateFrame) error {
sc.priorityAware = true
if _, ok := sc.writeSched.(*priorityWriteSchedulerRFC9218); !ok {
return nil
}
p, ok := parseRFC9218Priority(f.Priority, sc.priorityAware)
if !ok {
return sc.countError("unparsable_priority_update", streamError(f.PrioritizedStreamID, ErrCodeProtocol))
}
sc.writeSched.AdjustStream(f.PrioritizedStreamID, p)
return nil
}
func (sc *serverConn) newStream(id, pusherID uint32, state streamState, priority PriorityParam) *stream {
sc.serveG.check() sc.serveG.check()
if id == 0 { if id == 0 {
panic("internal error: cannot create stream with id 0") panic("internal error: cannot create stream with id 0")
@ -2117,13 +2276,13 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.cw.Init() st.cw.Init()
st.flow.conn = &sc.flow // link to conn-level counter st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize) st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize()) st.inflow.init(sc.initialStreamRecvWindowSize)
if sc.hs.WriteTimeout > 0 { if sc.hs.WriteTimeout > 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
} }
sc.streams[id] = st sc.streams[id] = st
sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID}) sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID, priority: priority})
if st.isPushed() { if st.isPushed() {
sc.curPushedStreams++ sc.curPushedStreams++
} else { } else {
@ -2139,19 +2298,25 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) { func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) {
sc.serveG.check() sc.serveG.check()
rp := requestParam{ rp := httpcommon.ServerRequestParam{
method: f.PseudoValue("method"), Method: f.PseudoValue("method"),
scheme: f.PseudoValue("scheme"), Scheme: f.PseudoValue("scheme"),
authority: f.PseudoValue("authority"), Authority: f.PseudoValue("authority"),
path: f.PseudoValue("path"), Path: f.PseudoValue("path"),
Protocol: f.PseudoValue("protocol"),
} }
isConnect := rp.method == "CONNECT" // extended connect is disabled, so we should not see :protocol
if disableExtendedConnectProtocol && rp.Protocol != "" {
return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
}
isConnect := rp.Method == "CONNECT"
if isConnect { if isConnect {
if rp.path != "" || rp.scheme != "" || rp.authority == "" { if rp.Protocol == "" && (rp.Path != "" || rp.Scheme != "" || rp.Authority == "") {
return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
} }
} else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { } else if rp.Method == "" || rp.Path == "" || (rp.Scheme != "https" && rp.Scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses: // See 8.1.2.6 Malformed Requests and Responses:
// //
// Malformed requests or responses that are detected // Malformed requests or responses that are detected
@ -2165,12 +2330,16 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol)) return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol))
} }
rp.header = make(http.Header) header := make(http.Header)
rp.Header = header
for _, hf := range f.RegularFields() { for _, hf := range f.RegularFields() {
rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) header.Add(sc.canonicalHeader(hf.Name), hf.Value)
} }
if rp.authority == "" { if rp.Authority == "" {
rp.authority = rp.header.Get("Host") rp.Authority = header.Get("Host")
}
if rp.Protocol != "" {
header.Set(":protocol", rp.Protocol)
} }
rw, req, err := sc.newWriterAndRequestNoBody(st, rp) rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
@ -2179,7 +2348,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
} }
bodyOpen := !f.StreamEnded() bodyOpen := !f.StreamEnded()
if bodyOpen { if bodyOpen {
if vv, ok := rp.header["Content-Length"]; ok { if vv, ok := rp.Header["Content-Length"]; ok {
if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil {
req.ContentLength = int64(cl) req.ContentLength = int64(cl)
} else { } else {
@ -2195,83 +2364,38 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
return rw, req, nil return rw, req, nil
} }
type requestParam struct { func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.ServerRequestParam) (*responseWriter, *http.Request, error) {
method string
scheme, authority, path string
header http.Header
}
func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) {
sc.serveG.check() sc.serveG.check()
var tlsState *tls.ConnectionState // nil if not scheme https var tlsState *tls.ConnectionState // nil if not scheme https
if rp.scheme == "https" { if rp.Scheme == "https" {
tlsState = sc.tlsState tlsState = sc.tlsState
} }
needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue") res := httpcommon.NewServerRequest(rp)
if needsContinue { if res.InvalidReason != "" {
rp.header.Del("Expect") return nil, nil, sc.countError(res.InvalidReason, streamError(st.id, ErrCodeProtocol))
}
// Merge Cookie headers into one "; "-delimited value.
if cookies := rp.header["Cookie"]; len(cookies) > 1 {
rp.header.Set("Cookie", strings.Join(cookies, "; "))
}
// Setup Trailers
var trailer http.Header
for _, v := range rp.header["Trailer"] {
for _, key := range strings.Split(v, ",") {
key = http.CanonicalHeaderKey(textproto.TrimString(key))
switch key {
case "Transfer-Encoding", "Trailer", "Content-Length":
// Bogus. (copy of http1 rules)
// Ignore.
default:
if trailer == nil {
trailer = make(http.Header)
}
trailer[key] = nil
}
}
}
delete(rp.header, "Trailer")
var url_ *url.URL
var requestURI string
if rp.method == "CONNECT" {
url_ = &url.URL{Host: rp.authority}
requestURI = rp.authority // mimic HTTP/1 server behavior
} else {
var err error
url_, err = url.ParseRequestURI(rp.path)
if err != nil {
return nil, nil, sc.countError("bad_path", streamError(st.id, ErrCodeProtocol))
}
requestURI = rp.path
} }
body := &requestBody{ body := &requestBody{
conn: sc, conn: sc,
stream: st, stream: st,
needsContinue: needsContinue, needsContinue: res.NeedsContinue,
} }
req := &http.Request{ req := (&http.Request{
Method: rp.method, Method: rp.Method,
URL: url_, URL: res.URL,
RemoteAddr: sc.remoteAddrStr, RemoteAddr: sc.remoteAddrStr,
Header: rp.header, Header: rp.Header,
RequestURI: requestURI, RequestURI: res.RequestURI,
Proto: "HTTP/2.0", Proto: "HTTP/2.0",
ProtoMajor: 2, ProtoMajor: 2,
ProtoMinor: 0, ProtoMinor: 0,
TLS: tlsState, TLS: tlsState,
Host: rp.authority, Host: rp.Authority,
Body: body, Body: body,
Trailer: trailer, Trailer: res.Trailer,
} }).WithContext(st.ctx)
req = req.WithContext(st.ctx)
rw := sc.newResponseWriter(st, req) rw := sc.newResponseWriter(st, req)
return rw, req, nil return rw, req, nil
} }
@ -2391,7 +2515,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro
// waiting for this frame to be written, so an http.Flush mid-handler // waiting for this frame to be written, so an http.Flush mid-handler
// writes out the correct value of keys, before a handler later potentially // writes out the correct value of keys, before a handler later potentially
// mutates it. // mutates it.
errc = errChanPool.Get().(chan error) errc = sc.srv.state.getErrChan()
} }
if err := sc.writeFrameFromHandler(FrameWriteRequest{ if err := sc.writeFrameFromHandler(FrameWriteRequest{
write: headerData, write: headerData,
@ -2403,7 +2527,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro
if errc != nil { if errc != nil {
select { select {
case err := <-errc: case err := <-errc:
errChanPool.Put(errc) sc.srv.state.putErrChan(errc)
return err return err
case <-sc.doneServing: case <-sc.doneServing:
return errClientDisconnected return errClientDisconnected
@ -2510,7 +2634,7 @@ func (b *requestBody) Read(p []byte) (n int, err error) {
if err == io.EOF { if err == io.EOF {
b.sawEOF = true b.sawEOF = true
} }
if b.conn == nil && inTests { if b.conn == nil {
return return
} }
b.conn.noteBodyReadFromHandler(b.stream, n, err) b.conn.noteBodyReadFromHandler(b.stream, n, err)
@ -2811,6 +2935,11 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
return nil return nil
} }
func (w *responseWriter) EnableFullDuplex() error {
// We always support full duplex responses, so this is a no-op.
return nil
}
func (w *responseWriter) Flush() { func (w *responseWriter) Flush() {
w.FlushError() w.FlushError()
} }
@ -3079,7 +3208,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error {
method: opts.Method, method: opts.Method,
url: u, url: u,
header: cloneHeader(opts.Header), header: cloneHeader(opts.Header),
done: errChanPool.Get().(chan error), done: sc.srv.state.getErrChan(),
} }
select { select {
@ -3096,7 +3225,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error {
case <-st.cw: case <-st.cw:
return errStreamClosed return errStreamClosed
case err := <-msg.done: case err := <-msg.done:
errChanPool.Put(msg.done) sc.srv.state.putErrChan(msg.done)
return err return err
} }
} }
@ -3159,13 +3288,13 @@ func (sc *serverConn) startPush(msg *startPushRequest) {
// transition to "half closed (remote)" after sending the initial HEADERS, but // transition to "half closed (remote)" after sending the initial HEADERS, but
// we start in "half closed (remote)" for simplicity. // we start in "half closed (remote)" for simplicity.
// See further comments at the definition of stateHalfClosedRemote. // See further comments at the definition of stateHalfClosedRemote.
promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote) promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote, defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary))
rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{ rw, req, err := sc.newWriterAndRequestNoBody(promised, httpcommon.ServerRequestParam{
method: msg.method, Method: msg.method,
scheme: msg.url.Scheme, Scheme: msg.url.Scheme,
authority: msg.url.Host, Authority: msg.url.Host,
path: msg.url.RequestURI(), Path: msg.url.RequestURI(),
header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE Header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE
}) })
if err != nil { if err != nil {
// Should not happen, since we've already validated msg.url. // Should not happen, since we've already validated msg.url.
@ -3257,7 +3386,7 @@ func (sc *serverConn) countError(name string, err error) error {
if sc == nil || sc.srv == nil { if sc == nil || sc.srv == nil {
return err return err
} }
f := sc.srv.CountError f := sc.countErrorFunc
if f == nil { if f == nil {
return err return err
} }

View file

@ -1,331 +0,0 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"context"
"sync"
"time"
)
// testSyncHooks coordinates goroutines in tests.
//
// For example, a call to ClientConn.RoundTrip involves several goroutines, including:
// - the goroutine running RoundTrip;
// - the clientStream.doRequest goroutine, which writes the request; and
// - the clientStream.readLoop goroutine, which reads the response.
//
// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
// are blocked waiting for some condition such as reading the Request.Body or waiting for
// flow control to become available.
//
// The testSyncHooks also manage timers and synthetic time in tests.
// This permits us to, for example, start a request and cause it to time out waiting for
// response headers without resorting to time.Sleep calls.
type testSyncHooks struct {
// active/inactive act as a mutex and condition variable.
//
// - neither chan contains a value: testSyncHooks is locked.
// - active contains a value: unlocked, and at least one goroutine is not blocked
// - inactive contains a value: unlocked, and all goroutines are blocked
active chan struct{}
inactive chan struct{}
// goroutine counts
total int // total goroutines
condwait map[*sync.Cond]int // blocked in sync.Cond.Wait
blocked []*testBlockedGoroutine // otherwise blocked
// fake time
now time.Time
timers []*fakeTimer
// Transport testing: Report various events.
newclientconn func(*ClientConn)
newstream func(*clientStream)
}
// testBlockedGoroutine is a blocked goroutine.
type testBlockedGoroutine struct {
f func() bool // blocked until f returns true
ch chan struct{} // closed when unblocked
}
func newTestSyncHooks() *testSyncHooks {
h := &testSyncHooks{
active: make(chan struct{}, 1),
inactive: make(chan struct{}, 1),
condwait: map[*sync.Cond]int{},
}
h.inactive <- struct{}{}
h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
return h
}
// lock acquires the testSyncHooks mutex.
func (h *testSyncHooks) lock() {
select {
case <-h.active:
case <-h.inactive:
}
}
// waitInactive waits for all goroutines to become inactive.
func (h *testSyncHooks) waitInactive() {
for {
<-h.inactive
if !h.unlock() {
break
}
}
}
// unlock releases the testSyncHooks mutex.
// It reports whether any goroutines are active.
func (h *testSyncHooks) unlock() (active bool) {
// Look for a blocked goroutine which can be unblocked.
blocked := h.blocked[:0]
unblocked := false
for _, b := range h.blocked {
if !unblocked && b.f() {
unblocked = true
close(b.ch)
} else {
blocked = append(blocked, b)
}
}
h.blocked = blocked
// Count goroutines blocked on condition variables.
condwait := 0
for _, count := range h.condwait {
condwait += count
}
if h.total > condwait+len(blocked) {
h.active <- struct{}{}
return true
} else {
h.inactive <- struct{}{}
return false
}
}
// goRun starts a new goroutine.
func (h *testSyncHooks) goRun(f func()) {
h.lock()
h.total++
h.unlock()
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
f()
}()
}
// blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
// It waits until f returns true before proceeding.
//
// Example usage:
//
// h.blockUntil(func() bool {
// // Is the context done yet?
// select {
// case <-ctx.Done():
// default:
// return false
// }
// return true
// })
// // Wait for the context to become done.
// <-ctx.Done()
//
// The function f passed to blockUntil must be non-blocking and idempotent.
func (h *testSyncHooks) blockUntil(f func() bool) {
if f() {
return
}
ch := make(chan struct{})
h.lock()
h.blocked = append(h.blocked, &testBlockedGoroutine{
f: f,
ch: ch,
})
h.unlock()
<-ch
}
// broadcast is sync.Cond.Broadcast.
func (h *testSyncHooks) condBroadcast(cond *sync.Cond) {
h.lock()
delete(h.condwait, cond)
h.unlock()
cond.Broadcast()
}
// broadcast is sync.Cond.Wait.
func (h *testSyncHooks) condWait(cond *sync.Cond) {
h.lock()
h.condwait[cond]++
h.unlock()
}
// newTimer creates a new fake timer.
func (h *testSyncHooks) newTimer(d time.Duration) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
c: make(chan time.Time),
}
h.timers = append(h.timers, t)
return t
}
// afterFunc creates a new fake AfterFunc timer.
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
f: f,
}
h.timers = append(h.timers, t)
return t
}
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
t := h.afterFunc(d, cancel)
return ctx, func() {
t.Stop()
cancel()
}
}
func (h *testSyncHooks) timeUntilEvent() time.Duration {
h.lock()
defer h.unlock()
var next time.Time
for _, t := range h.timers {
if next.IsZero() || t.when.Before(next) {
next = t.when
}
}
if d := next.Sub(h.now); d > 0 {
return d
}
return 0
}
// advance advances time and causes synthetic timers to fire.
func (h *testSyncHooks) advance(d time.Duration) {
h.lock()
defer h.unlock()
h.now = h.now.Add(d)
timers := h.timers[:0]
for _, t := range h.timers {
t := t // remove after go.mod depends on go1.22
t.mu.Lock()
switch {
case t.when.After(h.now):
timers = append(timers, t)
case t.when.IsZero():
// stopped timer
default:
t.when = time.Time{}
if t.c != nil {
close(t.c)
}
if t.f != nil {
h.total++
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
t.f()
}()
}
}
t.mu.Unlock()
}
h.timers = timers
}
// A timer wraps a time.Timer, or a synthetic equivalent in tests.
// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
type timer interface {
C() <-chan time.Time
Stop() bool
Reset(d time.Duration) bool
}
// timeTimer implements timer using real time.
type timeTimer struct {
t *time.Timer
c chan time.Time
}
// newTimeTimer creates a new timer using real time.
func newTimeTimer(d time.Duration) timer {
ch := make(chan time.Time)
t := time.AfterFunc(d, func() {
close(ch)
})
return &timeTimer{t, ch}
}
// newTimeAfterFunc creates an AfterFunc timer using real time.
func newTimeAfterFunc(d time.Duration, f func()) timer {
return &timeTimer{
t: time.AfterFunc(d, f),
}
}
func (t timeTimer) C() <-chan time.Time { return t.c }
func (t timeTimer) Stop() bool { return t.t.Stop() }
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
// fakeTimer implements timer using fake time.
type fakeTimer struct {
hooks *testSyncHooks
mu sync.Mutex
when time.Time // when the timer will fire
c chan time.Time // closed when the timer fires; mutually exclusive with f
f func() // called when the timer fires; mutually exclusive with c
}
func (t *fakeTimer) C() <-chan time.Time { return t.c }
func (t *fakeTimer) Stop() bool {
t.mu.Lock()
defer t.mu.Unlock()
stopped := t.when.IsZero()
t.when = time.Time{}
return stopped
}
func (t *fakeTimer) Reset(d time.Duration) bool {
if t.c != nil || t.f == nil {
panic("fakeTimer only supports Reset on AfterFunc timers")
}
t.mu.Lock()
defer t.mu.Unlock()
t.hooks.lock()
defer t.hooks.unlock()
active := !t.when.IsZero()
t.when = t.hooks.now.Add(d)
if !active {
t.hooks.timers = append(t.hooks.timers, t)
}
return active
}

File diff suppressed because it is too large Load diff

32
vendor/golang.org/x/net/http2/unencrypted.go generated vendored Normal file
View file

@ -0,0 +1,32 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"crypto/tls"
"errors"
"net"
)
const nextProtoUnencryptedHTTP2 = "unencrypted_http2"
// unencryptedNetConnFromTLSConn retrieves a net.Conn wrapped in a *tls.Conn.
//
// TLSNextProto functions accept a *tls.Conn.
//
// When passing an unencrypted HTTP/2 connection to a TLSNextProto function,
// we pass a *tls.Conn with an underlying net.Conn containing the unencrypted connection.
// To be extra careful about mistakes (accidentally dropping TLS encryption in a place
// where we want it), the tls.Conn contains a net.Conn with an UnencryptedNetConn method
// that returns the actual connection we want to use.
func unencryptedNetConnFromTLSConn(tc *tls.Conn) (net.Conn, error) {
conner, ok := tc.NetConn().(interface {
UnencryptedNetConn() net.Conn
})
if !ok {
return nil, errors.New("http2: TLS conn unexpectedly found in unencrypted handoff")
}
return conner.UnencryptedNetConn(), nil
}

View file

@ -13,6 +13,7 @@ import (
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"golang.org/x/net/internal/httpcommon"
) )
// writeFramer is implemented by any type that is used to write frames. // writeFramer is implemented by any type that is used to write frames.
@ -131,6 +132,16 @@ func (se StreamError) writeFrame(ctx writeContext) error {
func (se StreamError) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max } func (se StreamError) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max }
type writePing struct {
data [8]byte
}
func (w writePing) writeFrame(ctx writeContext) error {
return ctx.Framer().WritePing(false, w.data)
}
func (w writePing) staysWithinBuffer(max int) bool { return frameHeaderLen+len(w.data) <= max }
type writePingAck struct{ pf *PingFrame } type writePingAck struct{ pf *PingFrame }
func (w writePingAck) writeFrame(ctx writeContext) error { func (w writePingAck) writeFrame(ctx writeContext) error {
@ -341,7 +352,7 @@ func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
} }
for _, k := range keys { for _, k := range keys {
vv := h[k] vv := h[k]
k, ascii := lowerHeader(k) k, ascii := httpcommon.LowerHeader(k)
if !ascii { if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x). // field names have to be ASCII characters (just as in HTTP/1.x).

View file

@ -42,6 +42,8 @@ type OpenStreamOptions struct {
// PusherID is zero if the stream was initiated by the client. Otherwise, // PusherID is zero if the stream was initiated by the client. Otherwise,
// PusherID names the stream that pushed the newly opened stream. // PusherID names the stream that pushed the newly opened stream.
PusherID uint32 PusherID uint32
// priority is used to set the priority of the newly opened stream.
priority PriorityParam
} }
// FrameWriteRequest is a request to write a frame. // FrameWriteRequest is a request to write a frame.
@ -183,45 +185,75 @@ func (wr *FrameWriteRequest) replyToWriter(err error) {
} }
// writeQueue is used by implementations of WriteScheduler. // writeQueue is used by implementations of WriteScheduler.
//
// Each writeQueue contains a queue of FrameWriteRequests, meant to store all
// FrameWriteRequests associated with a given stream. This is implemented as a
// two-stage queue: currQueue[currPos:] and nextQueue. Removing an item is done
// by incrementing currPos of currQueue. Adding an item is done by appending it
// to the nextQueue. If currQueue is empty when trying to remove an item, we
// can swap currQueue and nextQueue to remedy the situation.
// This two-stage queue is analogous to the use of two lists in Okasaki's
// purely functional queue but without the overhead of reversing the list when
// swapping stages.
//
// writeQueue also contains prev and next, this can be used by implementations
// of WriteScheduler to construct data structures that represent the order of
// writing between different streams (e.g. circular linked list).
type writeQueue struct { type writeQueue struct {
s []FrameWriteRequest currQueue []FrameWriteRequest
nextQueue []FrameWriteRequest
currPos int
prev, next *writeQueue prev, next *writeQueue
} }
func (q *writeQueue) empty() bool { return len(q.s) == 0 } func (q *writeQueue) empty() bool {
return (len(q.currQueue) - q.currPos + len(q.nextQueue)) == 0
}
func (q *writeQueue) push(wr FrameWriteRequest) { func (q *writeQueue) push(wr FrameWriteRequest) {
q.s = append(q.s, wr) q.nextQueue = append(q.nextQueue, wr)
} }
func (q *writeQueue) shift() FrameWriteRequest { func (q *writeQueue) shift() FrameWriteRequest {
if len(q.s) == 0 { if q.empty() {
panic("invalid use of queue") panic("invalid use of queue")
} }
wr := q.s[0] if q.currPos >= len(q.currQueue) {
// TODO: less copy-happy queue. q.currQueue, q.currPos, q.nextQueue = q.nextQueue, 0, q.currQueue[:0]
copy(q.s, q.s[1:]) }
q.s[len(q.s)-1] = FrameWriteRequest{} wr := q.currQueue[q.currPos]
q.s = q.s[:len(q.s)-1] q.currQueue[q.currPos] = FrameWriteRequest{}
q.currPos++
return wr return wr
} }
func (q *writeQueue) peek() *FrameWriteRequest {
if q.currPos < len(q.currQueue) {
return &q.currQueue[q.currPos]
}
if len(q.nextQueue) > 0 {
return &q.nextQueue[0]
}
return nil
}
// consume consumes up to n bytes from q.s[0]. If the frame is // consume consumes up to n bytes from q.s[0]. If the frame is
// entirely consumed, it is removed from the queue. If the frame // entirely consumed, it is removed from the queue. If the frame
// is partially consumed, the frame is kept with the consumed // is partially consumed, the frame is kept with the consumed
// bytes removed. Returns true iff any bytes were consumed. // bytes removed. Returns true iff any bytes were consumed.
func (q *writeQueue) consume(n int32) (FrameWriteRequest, bool) { func (q *writeQueue) consume(n int32) (FrameWriteRequest, bool) {
if len(q.s) == 0 { if q.empty() {
return FrameWriteRequest{}, false return FrameWriteRequest{}, false
} }
consumed, rest, numresult := q.s[0].Consume(n) consumed, rest, numresult := q.peek().Consume(n)
switch numresult { switch numresult {
case 0: case 0:
return FrameWriteRequest{}, false return FrameWriteRequest{}, false
case 1: case 1:
q.shift() q.shift()
case 2: case 2:
q.s[0] = rest *q.peek() = rest
} }
return consumed, true return consumed, true
} }
@ -230,10 +262,15 @@ type writeQueuePool []*writeQueue
// put inserts an unused writeQueue into the pool. // put inserts an unused writeQueue into the pool.
func (p *writeQueuePool) put(q *writeQueue) { func (p *writeQueuePool) put(q *writeQueue) {
for i := range q.s { for i := range q.currQueue {
q.s[i] = FrameWriteRequest{} q.currQueue[i] = FrameWriteRequest{}
} }
q.s = q.s[:0] for i := range q.nextQueue {
q.nextQueue[i] = FrameWriteRequest{}
}
q.currQueue = q.currQueue[:0]
q.nextQueue = q.nextQueue[:0]
q.currPos = 0
*p = append(*p, q) *p = append(*p, q)
} }

View file

@ -11,7 +11,7 @@ import (
) )
// RFC 7540, Section 5.3.5: the default weight is 16. // RFC 7540, Section 5.3.5: the default weight is 16.
const priorityDefaultWeight = 15 // 16 = 15 + 1 const priorityDefaultWeightRFC7540 = 15 // 16 = 15 + 1
// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. // PriorityWriteSchedulerConfig configures a priorityWriteScheduler.
type PriorityWriteSchedulerConfig struct { type PriorityWriteSchedulerConfig struct {
@ -56,6 +56,10 @@ type PriorityWriteSchedulerConfig struct {
// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3. // frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3.
// If cfg is nil, default options are used. // If cfg is nil, default options are used.
func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler { func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler {
return newPriorityWriteSchedulerRFC7540(cfg)
}
func newPriorityWriteSchedulerRFC7540(cfg *PriorityWriteSchedulerConfig) WriteScheduler {
if cfg == nil { if cfg == nil {
// For justification of these defaults, see: // For justification of these defaults, see:
// https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY // https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY
@ -66,8 +70,8 @@ func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler
} }
} }
ws := &priorityWriteScheduler{ ws := &priorityWriteSchedulerRFC7540{
nodes: make(map[uint32]*priorityNode), nodes: make(map[uint32]*priorityNodeRFC7540),
maxClosedNodesInTree: cfg.MaxClosedNodesInTree, maxClosedNodesInTree: cfg.MaxClosedNodesInTree,
maxIdleNodesInTree: cfg.MaxIdleNodesInTree, maxIdleNodesInTree: cfg.MaxIdleNodesInTree,
enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, enableWriteThrottle: cfg.ThrottleOutOfOrderWrites,
@ -81,32 +85,32 @@ func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler
return ws return ws
} }
type priorityNodeState int type priorityNodeStateRFC7540 int
const ( const (
priorityNodeOpen priorityNodeState = iota priorityNodeOpenRFC7540 priorityNodeStateRFC7540 = iota
priorityNodeClosed priorityNodeClosedRFC7540
priorityNodeIdle priorityNodeIdleRFC7540
) )
// priorityNode is a node in an HTTP/2 priority tree. // priorityNodeRFC7540 is a node in an HTTP/2 priority tree.
// Each node is associated with a single stream ID. // Each node is associated with a single stream ID.
// See RFC 7540, Section 5.3. // See RFC 7540, Section 5.3.
type priorityNode struct { type priorityNodeRFC7540 struct {
q writeQueue // queue of pending frames to write q writeQueue // queue of pending frames to write
id uint32 // id of the stream, or 0 for the root of the tree id uint32 // id of the stream, or 0 for the root of the tree
weight uint8 // the actual weight is weight+1, so the value is in [1,256] weight uint8 // the actual weight is weight+1, so the value is in [1,256]
state priorityNodeState // open | closed | idle state priorityNodeStateRFC7540 // open | closed | idle
bytes int64 // number of bytes written by this node, or 0 if closed bytes int64 // number of bytes written by this node, or 0 if closed
subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree
// These links form the priority tree. // These links form the priority tree.
parent *priorityNode parent *priorityNodeRFC7540
kids *priorityNode // start of the kids list kids *priorityNodeRFC7540 // start of the kids list
prev, next *priorityNode // doubly-linked list of siblings prev, next *priorityNodeRFC7540 // doubly-linked list of siblings
} }
func (n *priorityNode) setParent(parent *priorityNode) { func (n *priorityNodeRFC7540) setParent(parent *priorityNodeRFC7540) {
if n == parent { if n == parent {
panic("setParent to self") panic("setParent to self")
} }
@ -141,7 +145,7 @@ func (n *priorityNode) setParent(parent *priorityNode) {
} }
} }
func (n *priorityNode) addBytes(b int64) { func (n *priorityNodeRFC7540) addBytes(b int64) {
n.bytes += b n.bytes += b
for ; n != nil; n = n.parent { for ; n != nil; n = n.parent {
n.subtreeBytes += b n.subtreeBytes += b
@ -154,7 +158,7 @@ func (n *priorityNode) addBytes(b int64) {
// //
// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true // f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true
// if any ancestor p of n is still open (ignoring the root node). // if any ancestor p of n is still open (ignoring the root node).
func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f func(*priorityNode, bool) bool) bool { func (n *priorityNodeRFC7540) walkReadyInOrder(openParent bool, tmp *[]*priorityNodeRFC7540, f func(*priorityNodeRFC7540, bool) bool) bool {
if !n.q.empty() && f(n, openParent) { if !n.q.empty() && f(n, openParent) {
return true return true
} }
@ -165,7 +169,7 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f
// Don't consider the root "open" when updating openParent since // Don't consider the root "open" when updating openParent since
// we can't send data frames on the root stream (only control frames). // we can't send data frames on the root stream (only control frames).
if n.id != 0 { if n.id != 0 {
openParent = openParent || (n.state == priorityNodeOpen) openParent = openParent || (n.state == priorityNodeOpenRFC7540)
} }
// Common case: only one kid or all kids have the same weight. // Common case: only one kid or all kids have the same weight.
@ -195,7 +199,7 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f
*tmp = append(*tmp, n.kids) *tmp = append(*tmp, n.kids)
n.kids.setParent(nil) n.kids.setParent(nil)
} }
sort.Sort(sortPriorityNodeSiblings(*tmp)) sort.Sort(sortPriorityNodeSiblingsRFC7540(*tmp))
for i := len(*tmp) - 1; i >= 0; i-- { for i := len(*tmp) - 1; i >= 0; i-- {
(*tmp)[i].setParent(n) // setParent inserts at the head of n.kids (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids
} }
@ -207,15 +211,15 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f
return false return false
} }
type sortPriorityNodeSiblings []*priorityNode type sortPriorityNodeSiblingsRFC7540 []*priorityNodeRFC7540
func (z sortPriorityNodeSiblings) Len() int { return len(z) } func (z sortPriorityNodeSiblingsRFC7540) Len() int { return len(z) }
func (z sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } func (z sortPriorityNodeSiblingsRFC7540) Swap(i, k int) { z[i], z[k] = z[k], z[i] }
func (z sortPriorityNodeSiblings) Less(i, k int) bool { func (z sortPriorityNodeSiblingsRFC7540) Less(i, k int) bool {
// Prefer the subtree that has sent fewer bytes relative to its weight. // Prefer the subtree that has sent fewer bytes relative to its weight.
// See sections 5.3.2 and 5.3.4. // See sections 5.3.2 and 5.3.4.
wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) wi, bi := float64(z[i].weight)+1, float64(z[i].subtreeBytes)
wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) wk, bk := float64(z[k].weight)+1, float64(z[k].subtreeBytes)
if bi == 0 && bk == 0 { if bi == 0 && bk == 0 {
return wi >= wk return wi >= wk
} }
@ -225,13 +229,13 @@ func (z sortPriorityNodeSiblings) Less(i, k int) bool {
return bi/bk <= wi/wk return bi/bk <= wi/wk
} }
type priorityWriteScheduler struct { type priorityWriteSchedulerRFC7540 struct {
// root is the root of the priority tree, where root.id = 0. // root is the root of the priority tree, where root.id = 0.
// The root queues control frames that are not associated with any stream. // The root queues control frames that are not associated with any stream.
root priorityNode root priorityNodeRFC7540
// nodes maps stream ids to priority tree nodes. // nodes maps stream ids to priority tree nodes.
nodes map[uint32]*priorityNode nodes map[uint32]*priorityNodeRFC7540
// maxID is the maximum stream id in nodes. // maxID is the maximum stream id in nodes.
maxID uint32 maxID uint32
@ -239,7 +243,7 @@ type priorityWriteScheduler struct {
// lists of nodes that have been closed or are idle, but are kept in // lists of nodes that have been closed or are idle, but are kept in
// the tree for improved prioritization. When the lengths exceed either // the tree for improved prioritization. When the lengths exceed either
// maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded.
closedNodes, idleNodes []*priorityNode closedNodes, idleNodes []*priorityNodeRFC7540
// From the config. // From the config.
maxClosedNodesInTree int maxClosedNodesInTree int
@ -248,19 +252,19 @@ type priorityWriteScheduler struct {
enableWriteThrottle bool enableWriteThrottle bool
// tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations.
tmp []*priorityNode tmp []*priorityNodeRFC7540
// pool of empty queues for reuse. // pool of empty queues for reuse.
queuePool writeQueuePool queuePool writeQueuePool
} }
func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { func (ws *priorityWriteSchedulerRFC7540) OpenStream(streamID uint32, options OpenStreamOptions) {
// The stream may be currently idle but cannot be opened or closed. // The stream may be currently idle but cannot be opened or closed.
if curr := ws.nodes[streamID]; curr != nil { if curr := ws.nodes[streamID]; curr != nil {
if curr.state != priorityNodeIdle { if curr.state != priorityNodeIdleRFC7540 {
panic(fmt.Sprintf("stream %d already opened", streamID)) panic(fmt.Sprintf("stream %d already opened", streamID))
} }
curr.state = priorityNodeOpen curr.state = priorityNodeOpenRFC7540
return return
} }
@ -272,11 +276,11 @@ func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStream
if parent == nil { if parent == nil {
parent = &ws.root parent = &ws.root
} }
n := &priorityNode{ n := &priorityNodeRFC7540{
q: *ws.queuePool.get(), q: *ws.queuePool.get(),
id: streamID, id: streamID,
weight: priorityDefaultWeight, weight: priorityDefaultWeightRFC7540,
state: priorityNodeOpen, state: priorityNodeOpenRFC7540,
} }
n.setParent(parent) n.setParent(parent)
ws.nodes[streamID] = n ws.nodes[streamID] = n
@ -285,24 +289,23 @@ func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStream
} }
} }
func (ws *priorityWriteScheduler) CloseStream(streamID uint32) { func (ws *priorityWriteSchedulerRFC7540) CloseStream(streamID uint32) {
if streamID == 0 { if streamID == 0 {
panic("violation of WriteScheduler interface: cannot close stream 0") panic("violation of WriteScheduler interface: cannot close stream 0")
} }
if ws.nodes[streamID] == nil { if ws.nodes[streamID] == nil {
panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID))
} }
if ws.nodes[streamID].state != priorityNodeOpen { if ws.nodes[streamID].state != priorityNodeOpenRFC7540 {
panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID))
} }
n := ws.nodes[streamID] n := ws.nodes[streamID]
n.state = priorityNodeClosed n.state = priorityNodeClosedRFC7540
n.addBytes(-n.bytes) n.addBytes(-n.bytes)
q := n.q q := n.q
ws.queuePool.put(&q) ws.queuePool.put(&q)
n.q.s = nil
if ws.maxClosedNodesInTree > 0 { if ws.maxClosedNodesInTree > 0 {
ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n)
} else { } else {
@ -310,7 +313,7 @@ func (ws *priorityWriteScheduler) CloseStream(streamID uint32) {
} }
} }
func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) { func (ws *priorityWriteSchedulerRFC7540) AdjustStream(streamID uint32, priority PriorityParam) {
if streamID == 0 { if streamID == 0 {
panic("adjustPriority on root") panic("adjustPriority on root")
} }
@ -324,11 +327,11 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit
return return
} }
ws.maxID = streamID ws.maxID = streamID
n = &priorityNode{ n = &priorityNodeRFC7540{
q: *ws.queuePool.get(), q: *ws.queuePool.get(),
id: streamID, id: streamID,
weight: priorityDefaultWeight, weight: priorityDefaultWeightRFC7540,
state: priorityNodeIdle, state: priorityNodeIdleRFC7540,
} }
n.setParent(&ws.root) n.setParent(&ws.root)
ws.nodes[streamID] = n ws.nodes[streamID] = n
@ -340,7 +343,7 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit
parent := ws.nodes[priority.StreamDep] parent := ws.nodes[priority.StreamDep]
if parent == nil { if parent == nil {
n.setParent(&ws.root) n.setParent(&ws.root)
n.weight = priorityDefaultWeight n.weight = priorityDefaultWeightRFC7540
return return
} }
@ -381,8 +384,8 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit
n.weight = priority.Weight n.weight = priority.Weight
} }
func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) { func (ws *priorityWriteSchedulerRFC7540) Push(wr FrameWriteRequest) {
var n *priorityNode var n *priorityNodeRFC7540
if wr.isControl() { if wr.isControl() {
n = &ws.root n = &ws.root
} else { } else {
@ -401,8 +404,8 @@ func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) {
n.q.push(wr) n.q.push(wr)
} }
func (ws *priorityWriteScheduler) Pop() (wr FrameWriteRequest, ok bool) { func (ws *priorityWriteSchedulerRFC7540) Pop() (wr FrameWriteRequest, ok bool) {
ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNode, openParent bool) bool { ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNodeRFC7540, openParent bool) bool {
limit := int32(math.MaxInt32) limit := int32(math.MaxInt32)
if openParent { if openParent {
limit = ws.writeThrottleLimit limit = ws.writeThrottleLimit
@ -428,7 +431,7 @@ func (ws *priorityWriteScheduler) Pop() (wr FrameWriteRequest, ok bool) {
return wr, ok return wr, ok
} }
func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, maxSize int, n *priorityNode) { func (ws *priorityWriteSchedulerRFC7540) addClosedOrIdleNode(list *[]*priorityNodeRFC7540, maxSize int, n *priorityNodeRFC7540) {
if maxSize == 0 { if maxSize == 0 {
return return
} }
@ -442,9 +445,9 @@ func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, max
*list = append(*list, n) *list = append(*list, n)
} }
func (ws *priorityWriteScheduler) removeNode(n *priorityNode) { func (ws *priorityWriteSchedulerRFC7540) removeNode(n *priorityNodeRFC7540) {
for k := n.kids; k != nil; k = k.next { for n.kids != nil {
k.setParent(n.parent) n.kids.setParent(n.parent)
} }
n.setParent(nil) n.setParent(nil)
delete(ws.nodes, n.id) delete(ws.nodes, n.id)

View file

@ -0,0 +1,224 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"fmt"
"math"
)
type streamMetadata struct {
location *writeQueue
priority PriorityParam
}
type priorityWriteSchedulerRFC9218 struct {
// control contains control frames (SETTINGS, PING, etc.).
control writeQueue
// heads contain the head of a circular list of streams.
// We put these heads within a nested array that represents urgency and
// incremental, as defined in
// https://www.rfc-editor.org/rfc/rfc9218.html#name-priority-parameters.
// 8 represents u=0 up to u=7, and 2 represents i=false and i=true.
heads [8][2]*writeQueue
// streams contains a mapping between each stream ID and their metadata, so
// we can quickly locate them when needing to, for example, adjust their
// priority.
streams map[uint32]streamMetadata
// queuePool are empty queues for reuse.
queuePool writeQueuePool
// prioritizeIncremental is used to determine whether we should prioritize
// incremental streams or not, when urgency is the same in a given Pop()
// call.
prioritizeIncremental bool
// priorityUpdateBuf is used to buffer the most recent PRIORITY_UPDATE we
// receive per https://www.rfc-editor.org/rfc/rfc9218.html#name-the-priority_update-frame.
priorityUpdateBuf struct {
// streamID being 0 means that the buffer is empty. This is a safe
// assumption as PRIORITY_UPDATE for stream 0 is a PROTOCOL_ERROR.
streamID uint32
priority PriorityParam
}
}
func newPriorityWriteSchedulerRFC9218() WriteScheduler {
ws := &priorityWriteSchedulerRFC9218{
streams: make(map[uint32]streamMetadata),
}
return ws
}
func (ws *priorityWriteSchedulerRFC9218) OpenStream(streamID uint32, opt OpenStreamOptions) {
if ws.streams[streamID].location != nil {
panic(fmt.Errorf("stream %d already opened", streamID))
}
if streamID == ws.priorityUpdateBuf.streamID {
ws.priorityUpdateBuf.streamID = 0
opt.priority = ws.priorityUpdateBuf.priority
}
q := ws.queuePool.get()
ws.streams[streamID] = streamMetadata{
location: q,
priority: opt.priority,
}
u, i := opt.priority.urgency, opt.priority.incremental
if ws.heads[u][i] == nil {
ws.heads[u][i] = q
q.next = q
q.prev = q
} else {
// Queues are stored in a ring.
// Insert the new stream before ws.head, putting it at the end of the list.
q.prev = ws.heads[u][i].prev
q.next = ws.heads[u][i]
q.prev.next = q
q.next.prev = q
}
}
func (ws *priorityWriteSchedulerRFC9218) CloseStream(streamID uint32) {
metadata := ws.streams[streamID]
q, u, i := metadata.location, metadata.priority.urgency, metadata.priority.incremental
if q == nil {
return
}
if q.next == q {
// This was the only open stream.
ws.heads[u][i] = nil
} else {
q.prev.next = q.next
q.next.prev = q.prev
if ws.heads[u][i] == q {
ws.heads[u][i] = q.next
}
}
delete(ws.streams, streamID)
ws.queuePool.put(q)
}
func (ws *priorityWriteSchedulerRFC9218) AdjustStream(streamID uint32, priority PriorityParam) {
metadata := ws.streams[streamID]
q, u, i := metadata.location, metadata.priority.urgency, metadata.priority.incremental
if q == nil {
ws.priorityUpdateBuf.streamID = streamID
ws.priorityUpdateBuf.priority = priority
return
}
// Remove stream from current location.
if q.next == q {
// This was the only open stream.
ws.heads[u][i] = nil
} else {
q.prev.next = q.next
q.next.prev = q.prev
if ws.heads[u][i] == q {
ws.heads[u][i] = q.next
}
}
// Insert stream to the new queue.
u, i = priority.urgency, priority.incremental
if ws.heads[u][i] == nil {
ws.heads[u][i] = q
q.next = q
q.prev = q
} else {
// Queues are stored in a ring.
// Insert the new stream before ws.head, putting it at the end of the list.
q.prev = ws.heads[u][i].prev
q.next = ws.heads[u][i]
q.prev.next = q
q.next.prev = q
}
// Update the metadata.
ws.streams[streamID] = streamMetadata{
location: q,
priority: priority,
}
}
func (ws *priorityWriteSchedulerRFC9218) Push(wr FrameWriteRequest) {
if wr.isControl() {
ws.control.push(wr)
return
}
q := ws.streams[wr.StreamID()].location
if q == nil {
// This is a closed stream.
// wr should not be a HEADERS or DATA frame.
// We push the request onto the control queue.
if wr.DataSize() > 0 {
panic("add DATA on non-open stream")
}
ws.control.push(wr)
return
}
q.push(wr)
}
func (ws *priorityWriteSchedulerRFC9218) Pop() (FrameWriteRequest, bool) {
// Control and RST_STREAM frames first.
if !ws.control.empty() {
return ws.control.shift(), true
}
// On the next Pop(), we want to prioritize incremental if we prioritized
// non-incremental request of the same urgency this time. Vice-versa.
// i.e. when there are incremental and non-incremental requests at the same
// priority, we give 50% of our bandwidth to the incremental ones in
// aggregate and 50% to the first non-incremental one (since
// non-incremental streams do not use round-robin writes).
ws.prioritizeIncremental = !ws.prioritizeIncremental
// Always prioritize lowest u (i.e. highest urgency level).
for u := range ws.heads {
for i := range ws.heads[u] {
// When we want to prioritize incremental, we try to pop i=true
// first before i=false when u is the same.
if ws.prioritizeIncremental {
i = (i + 1) % 2
}
q := ws.heads[u][i]
if q == nil {
continue
}
for {
if wr, ok := q.consume(math.MaxInt32); ok {
if i == 1 {
// For incremental streams, we update head to q.next so
// we can round-robin between multiple streams that can
// immediately benefit from partial writes.
ws.heads[u][i] = q.next
} else {
// For non-incremental streams, we try to finish one to
// completion rather than doing round-robin. However,
// we update head here so that if q.consume() is !ok
// (e.g. the stream has no more frame to consume), head
// is updated to the next q that has frames to consume
// on future iterations. This way, we do not prioritize
// writing to unavailable stream on next Pop() calls,
// preventing head-of-line blocking.
ws.heads[u][i] = q
}
return wr, true
}
q = q.next
if q == ws.heads[u][i] {
break
}
}
}
}
return FrameWriteRequest{}, false
}

View file

@ -25,7 +25,7 @@ type roundRobinWriteScheduler struct {
} }
// newRoundRobinWriteScheduler constructs a new write scheduler. // newRoundRobinWriteScheduler constructs a new write scheduler.
// The round robin scheduler priorizes control frames // The round robin scheduler prioritizes control frames
// like SETTINGS and PING over DATA frames. // like SETTINGS and PING over DATA frames.
// When there are no control frames to send, it performs a round-robin // When there are no control frames to send, it performs a round-robin
// selection from the ready streams. // selection from the ready streams.

53
vendor/golang.org/x/net/internal/httpcommon/ascii.go generated vendored Normal file
View file

@ -0,0 +1,53 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httpcommon
import "strings"
// The HTTP protocols are defined in terms of ASCII, not Unicode. This file
// contains helper functions which may use Unicode-aware functions which would
// otherwise be unsafe and could introduce vulnerabilities if used improperly.
// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
// are equal, ASCII-case-insensitively.
func asciiEqualFold(s, t string) bool {
if len(s) != len(t) {
return false
}
for i := 0; i < len(s); i++ {
if lower(s[i]) != lower(t[i]) {
return false
}
}
return true
}
// lower returns the ASCII lowercase version of b.
func lower(b byte) byte {
if 'A' <= b && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
// isASCIIPrint returns whether s is ASCII and printable according to
// https://tools.ietf.org/html/rfc20#section-4.2.
func isASCIIPrint(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > '~' {
return false
}
}
return true
}
// asciiToLower returns the lowercase version of s if s is ASCII and printable,
// and whether or not it was.
func asciiToLower(s string) (lower string, ok bool) {
if !isASCIIPrint(s) {
return "", false
}
return strings.ToLower(s), true
}

View file

@ -1,11 +1,11 @@
// Copyright 2014 The Go Authors. All rights reserved. // Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package http2 package httpcommon
import ( import (
"net/http" "net/textproto"
"sync" "sync"
) )
@ -82,13 +82,15 @@ func buildCommonHeaderMaps() {
commonLowerHeader = make(map[string]string, len(common)) commonLowerHeader = make(map[string]string, len(common))
commonCanonHeader = make(map[string]string, len(common)) commonCanonHeader = make(map[string]string, len(common))
for _, v := range common { for _, v := range common {
chk := http.CanonicalHeaderKey(v) chk := textproto.CanonicalMIMEHeaderKey(v)
commonLowerHeader[chk] = v commonLowerHeader[chk] = v
commonCanonHeader[v] = chk commonCanonHeader[v] = chk
} }
} }
func lowerHeader(v string) (lower string, ascii bool) { // LowerHeader returns the lowercase form of a header name,
// used on the wire for HTTP/2 and HTTP/3 requests.
func LowerHeader(v string) (lower string, ascii bool) {
buildCommonHeaderMapsOnce() buildCommonHeaderMapsOnce()
if s, ok := commonLowerHeader[v]; ok { if s, ok := commonLowerHeader[v]; ok {
return s, true return s, true
@ -96,10 +98,18 @@ func lowerHeader(v string) (lower string, ascii bool) {
return asciiToLower(v) return asciiToLower(v)
} }
func canonicalHeader(v string) string { // CanonicalHeader canonicalizes a header name. (For example, "host" becomes "Host".)
func CanonicalHeader(v string) string {
buildCommonHeaderMapsOnce() buildCommonHeaderMapsOnce()
if s, ok := commonCanonHeader[v]; ok { if s, ok := commonCanonHeader[v]; ok {
return s return s
} }
return http.CanonicalHeaderKey(v) return textproto.CanonicalMIMEHeaderKey(v)
}
// CachedCanonicalHeader returns the canonical form of a well-known header name.
func CachedCanonicalHeader(v string) (string, bool) {
buildCommonHeaderMapsOnce()
s, ok := commonCanonHeader[v]
return s, ok
} }

467
vendor/golang.org/x/net/internal/httpcommon/request.go generated vendored Normal file
View file

@ -0,0 +1,467 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httpcommon
import (
"context"
"errors"
"fmt"
"net/http/httptrace"
"net/textproto"
"net/url"
"sort"
"strconv"
"strings"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
)
var (
ErrRequestHeaderListSize = errors.New("request header list larger than peer's advertised limit")
)
// Request is a subset of http.Request.
// It'd be simpler to pass an *http.Request, of course, but we can't depend on net/http
// without creating a dependency cycle.
type Request struct {
URL *url.URL
Method string
Host string
Header map[string][]string
Trailer map[string][]string
ActualContentLength int64 // 0 means 0, -1 means unknown
}
// EncodeHeadersParam is parameters to EncodeHeaders.
type EncodeHeadersParam struct {
Request Request
// AddGzipHeader indicates that an "accept-encoding: gzip" header should be
// added to the request.
AddGzipHeader bool
// PeerMaxHeaderListSize, when non-zero, is the peer's MAX_HEADER_LIST_SIZE setting.
PeerMaxHeaderListSize uint64
// DefaultUserAgent is the User-Agent header to send when the request
// neither contains a User-Agent nor disables it.
DefaultUserAgent string
}
// EncodeHeadersResult is the result of EncodeHeaders.
type EncodeHeadersResult struct {
HasBody bool
HasTrailers bool
}
// EncodeHeaders constructs request headers common to HTTP/2 and HTTP/3.
// It validates a request and calls headerf with each pseudo-header and header
// for the request.
// The headerf function is called with the validated, canonicalized header name.
func EncodeHeaders(ctx context.Context, param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) {
req := param.Request
// Check for invalid connection-level headers.
if err := checkConnHeaders(req.Header); err != nil {
return res, err
}
if req.URL == nil {
return res, errors.New("Request.URL is nil")
}
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return res, err
}
if !httpguts.ValidHostHeader(host) {
return res, errors.New("invalid Host header")
}
// isNormalConnect is true if this is a non-extended CONNECT request.
isNormalConnect := false
var protocol string
if vv := req.Header[":protocol"]; len(vv) > 0 {
protocol = vv[0]
}
if req.Method == "CONNECT" && protocol == "" {
isNormalConnect = true
} else if protocol != "" && req.Method != "CONNECT" {
return res, errors.New("invalid :protocol header in non-CONNECT request")
}
// Validate the path, except for non-extended CONNECT requests which have no path.
var path string
if !isNormalConnect {
path = req.URL.RequestURI()
if !validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return res, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
} else {
return res, fmt.Errorf("invalid request :path %q", orig)
}
}
}
}
// Check for any invalid headers+trailers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
if err := validateHeaders(req.Header); err != "" {
return res, fmt.Errorf("invalid HTTP header %s", err)
}
if err := validateHeaders(req.Trailer); err != "" {
return res, fmt.Errorf("invalid HTTP trailer %s", err)
}
trailers, err := commaSeparatedTrailers(req.Trailer)
if err != nil {
return res, err
}
enumerateHeaders := func(f func(name, value string)) {
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production, see Sections 3.3 and 3.4 of
// [RFC3986]).
f(":authority", host)
m := req.Method
if m == "" {
m = "GET"
}
f(":method", m)
if !isNormalConnect {
f(":path", path)
f(":scheme", req.URL.Scheme)
}
if protocol != "" {
f(":protocol", protocol)
}
if trailers != "" {
f("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
if asciiEqualFold(k, "host") || asciiEqualFold(k, "content-length") {
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
} else if asciiEqualFold(k, "connection") ||
asciiEqualFold(k, "proxy-connection") ||
asciiEqualFold(k, "transfer-encoding") ||
asciiEqualFold(k, "upgrade") ||
asciiEqualFold(k, "keep-alive") {
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
} else if asciiEqualFold(k, "user-agent") {
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
} else if asciiEqualFold(k, "cookie") {
// Per 8.1.2.5 To allow for better compression efficiency, the
// Cookie header field MAY be split into separate header fields,
// each with one or more cookie-pairs.
for _, v := range vv {
for {
p := strings.IndexByte(v, ';')
if p < 0 {
break
}
f("cookie", v[:p])
p++
// strip space after semicolon if any.
for p+1 <= len(v) && v[p] == ' ' {
p++
}
v = v[p:]
}
if len(v) > 0 {
f("cookie", v)
}
}
continue
} else if k == ":protocol" {
// :protocol pseudo-header was already sent above.
continue
}
for _, v := range vv {
f(k, v)
}
}
if shouldSendReqContentLength(req.Method, req.ActualContentLength) {
f("content-length", strconv.FormatInt(req.ActualContentLength, 10))
}
if param.AddGzipHeader {
f("accept-encoding", "gzip")
}
if !didUA {
f("user-agent", param.DefaultUserAgent)
}
}
// Do a first pass over the headers counting bytes to ensure
// we don't exceed cc.peerMaxHeaderListSize. This is done as a
// separate pass before encoding the headers to prevent
// modifying the hpack state.
if param.PeerMaxHeaderListSize > 0 {
hlSize := uint64(0)
enumerateHeaders(func(name, value string) {
hf := hpack.HeaderField{Name: name, Value: value}
hlSize += uint64(hf.Size())
})
if hlSize > param.PeerMaxHeaderListSize {
return res, ErrRequestHeaderListSize
}
}
trace := httptrace.ContextClientTrace(ctx)
// Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) {
name, ascii := LowerHeader(name)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).
return
}
headerf(name, value)
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(name, []string{value})
}
})
res.HasBody = req.ActualContentLength != 0
res.HasTrailers = trailers != ""
return res, nil
}
// IsRequestGzip reports whether we should add an Accept-Encoding: gzip header
// for a request.
func IsRequestGzip(method string, header map[string][]string, disableCompression bool) bool {
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !disableCompression &&
len(header["Accept-Encoding"]) == 0 &&
len(header["Range"]) == 0 &&
method != "HEAD" {
// Request gzip only, not deflate. Deflate is ambiguous and
// not as universally supported anyway.
// See: https://zlib.net/zlib_faq.html#faq39
//
// Note that we don't request this for HEAD requests,
// due to a bug in nginx:
// http://trac.nginx.org/nginx/ticket/358
// https://golang.org/issue/5522
//
// We don't request gzip if the request is for a range, since
// auto-decoding a portion of a gzipped document will just fail
// anyway. See https://golang.org/issue/8923
return true
}
return false
}
// checkConnHeaders checks whether req has any invalid connection-level headers.
//
// https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2-3
// https://www.rfc-editor.org/rfc/rfc9113.html#section-8.2.2-1
//
// Certain headers are special-cased as okay but not transmitted later.
// For example, we allow "Transfer-Encoding: chunked", but drop the header when encoding.
func checkConnHeaders(h map[string][]string) error {
if vv := h["Upgrade"]; len(vv) > 0 && (vv[0] != "" && vv[0] != "chunked") {
return fmt.Errorf("invalid Upgrade request header: %q", vv)
}
if vv := h["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
return fmt.Errorf("invalid Transfer-Encoding request header: %q", vv)
}
if vv := h["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) {
return fmt.Errorf("invalid Connection request header: %q", vv)
}
return nil
}
func commaSeparatedTrailers(trailer map[string][]string) (string, error) {
keys := make([]string, 0, len(trailer))
for k := range trailer {
k = CanonicalHeader(k)
switch k {
case "Transfer-Encoding", "Trailer", "Content-Length":
return "", fmt.Errorf("invalid Trailer key %q", k)
}
keys = append(keys, k)
}
if len(keys) > 0 {
sort.Strings(keys)
return strings.Join(keys, ","), nil
}
return "", nil
}
// validPseudoPath reports whether v is a valid :path pseudo-header
// value. It must be either:
//
// - a non-empty string starting with '/'
// - the string '*', for OPTIONS requests.
//
// For now this is only used a quick check for deciding when to clean
// up Opaque URLs before sending requests from the Transport.
// See golang.org/issue/16847
//
// We used to enforce that the path also didn't start with "//", but
// Google's GFE accepts such paths and Chrome sends them, so ignore
// that part of the spec. See golang.org/issue/19103.
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/') || v == "*"
}
func validateHeaders(hdrs map[string][]string) string {
for k, vv := range hdrs {
if !httpguts.ValidHeaderFieldName(k) && k != ":protocol" {
return fmt.Sprintf("name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error,
// because it may be sensitive.
return fmt.Sprintf("value for header %q", k)
}
}
}
return ""
}
// shouldSendReqContentLength reports whether we should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
// ServerRequestParam is parameters to NewServerRequest.
type ServerRequestParam struct {
Method string
Scheme, Authority, Path string
Protocol string
Header map[string][]string
}
// ServerRequestResult is the result of NewServerRequest.
type ServerRequestResult struct {
// Various http.Request fields.
URL *url.URL
RequestURI string
Trailer map[string][]string
NeedsContinue bool // client provided an "Expect: 100-continue" header
// If the request should be rejected, this is a short string suitable for passing
// to the http2 package's CountError function.
// It might be a bit odd to return errors this way rather than returning an error,
// but this ensures we don't forget to include a CountError reason.
InvalidReason string
}
func NewServerRequest(rp ServerRequestParam) ServerRequestResult {
needsContinue := httpguts.HeaderValuesContainsToken(rp.Header["Expect"], "100-continue")
if needsContinue {
delete(rp.Header, "Expect")
}
// Merge Cookie headers into one "; "-delimited value.
if cookies := rp.Header["Cookie"]; len(cookies) > 1 {
rp.Header["Cookie"] = []string{strings.Join(cookies, "; ")}
}
// Setup Trailers
var trailer map[string][]string
for _, v := range rp.Header["Trailer"] {
for _, key := range strings.Split(v, ",") {
key = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(key))
switch key {
case "Transfer-Encoding", "Trailer", "Content-Length":
// Bogus. (copy of http1 rules)
// Ignore.
default:
if trailer == nil {
trailer = make(map[string][]string)
}
trailer[key] = nil
}
}
}
delete(rp.Header, "Trailer")
// "':authority' MUST NOT include the deprecated userinfo subcomponent
// for "http" or "https" schemed URIs."
// https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8
if strings.IndexByte(rp.Authority, '@') != -1 && (rp.Scheme == "http" || rp.Scheme == "https") {
return ServerRequestResult{
InvalidReason: "userinfo_in_authority",
}
}
var url_ *url.URL
var requestURI string
if rp.Method == "CONNECT" && rp.Protocol == "" {
url_ = &url.URL{Host: rp.Authority}
requestURI = rp.Authority // mimic HTTP/1 server behavior
} else {
var err error
url_, err = url.ParseRequestURI(rp.Path)
if err != nil {
return ServerRequestResult{
InvalidReason: "bad_path",
}
}
requestURI = rp.Path
}
return ServerRequestResult{
URL: url_,
NeedsContinue: needsContinue,
RequestURI: requestURI,
Trailer: trailer,
}
}

665
vendor/golang.org/x/net/internal/httpsfv/httpsfv.go generated vendored Normal file
View file

@ -0,0 +1,665 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package httpsfv provides functionality for dealing with HTTP Structured
// Field Values.
package httpsfv
import (
"slices"
"strconv"
"strings"
"time"
"unicode/utf8"
)
func isLCAlpha(b byte) bool {
return (b >= 'a' && b <= 'z')
}
func isAlpha(b byte) bool {
return isLCAlpha(b) || (b >= 'A' && b <= 'Z')
}
func isDigit(b byte) bool {
return b >= '0' && b <= '9'
}
func isVChar(b byte) bool {
return b >= 0x21 && b <= 0x7e
}
func isSP(b byte) bool {
return b == 0x20
}
func isTChar(b byte) bool {
if isAlpha(b) || isDigit(b) {
return true
}
return slices.Contains([]byte{'!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~'}, b)
}
func countLeftWhitespace(s string) int {
i := 0
for _, ch := range []byte(s) {
if ch != ' ' && ch != '\t' {
break
}
i++
}
return i
}
// https://www.rfc-editor.org/rfc/rfc4648#section-8.
func decOctetHex(ch1, ch2 byte) (ch byte, ok bool) {
decBase16 := func(in byte) (out byte, ok bool) {
if !isDigit(in) && !(in >= 'a' && in <= 'f') {
return 0, false
}
if isDigit(in) {
return in - '0', true
}
return in - 'a' + 10, true
}
if ch1, ok = decBase16(ch1); !ok {
return 0, ok
}
if ch2, ok = decBase16(ch2); !ok {
return 0, ok
}
return ch1<<4 | ch2, true
}
// ParseList parses a list from a given HTTP Structured Field Values.
//
// Given an HTTP SFV string that represents a list, it will call the given
// function using each of the members and parameters contained in the list.
// This allows the caller to extract information out of the list.
//
// This function will return once it encounters the end of the string, or
// something that is not a list. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-list.
func ParseList(s string, f func(member, param string)) (ok bool) {
for len(s) != 0 {
var member, param string
if len(s) != 0 && s[0] == '(' {
if member, s, ok = consumeBareInnerList(s, nil); !ok {
return ok
}
} else {
if member, s, ok = consumeBareItem(s); !ok {
return ok
}
}
if param, s, ok = consumeParameter(s, nil); !ok {
return ok
}
if f != nil {
f(member, param)
}
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
break
}
if s[0] != ',' {
return false
}
s = s[1:]
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
return false
}
}
return true
}
// consumeBareInnerList consumes an inner list
// (https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-inner-list),
// except for the inner list's top-most parameter.
// For example, given `(a;b c;d);e`, it will consume only `(a;b c;d)`.
func consumeBareInnerList(s string, f func(bareItem, param string)) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != '(' {
return "", s, false
}
rest = s[1:]
for len(rest) != 0 {
var bareItem, param string
rest = rest[countLeftWhitespace(rest):]
if len(rest) != 0 && rest[0] == ')' {
rest = rest[1:]
break
}
if bareItem, rest, ok = consumeBareItem(rest); !ok {
return "", s, ok
}
if param, rest, ok = consumeParameter(rest, nil); !ok {
return "", s, ok
}
if len(rest) == 0 || (rest[0] != ')' && !isSP(rest[0])) {
return "", s, false
}
if f != nil {
f(bareItem, param)
}
}
return s[:len(s)-len(rest)], rest, true
}
// ParseBareInnerList parses a bare inner list from a given HTTP Structured
// Field Values.
//
// We define a bare inner list as an inner list
// (https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-inner-list),
// without the top-most parameter of the inner list. For example, given the
// inner list `(a;b c;d);e`, the bare inner list would be `(a;b c;d)`.
//
// Given an HTTP SFV string that represents a bare inner list, it will call the
// given function using each of the bare item and parameter within the bare
// inner list. This allows the caller to extract information out of the bare
// inner list.
//
// This function will return once it encounters the end of the bare inner list,
// or something that is not a bare inner list. If it cannot consume the entire
// given string, the ok value returned will be false.
func ParseBareInnerList(s string, f func(bareItem, param string)) (ok bool) {
_, rest, ok := consumeBareInnerList(s, f)
return rest == "" && ok
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-item.
func consumeItem(s string, f func(bareItem, param string)) (consumed, rest string, ok bool) {
var bareItem, param string
if bareItem, rest, ok = consumeBareItem(s); !ok {
return "", s, ok
}
if param, rest, ok = consumeParameter(rest, nil); !ok {
return "", s, ok
}
if f != nil {
f(bareItem, param)
}
return s[:len(s)-len(rest)], rest, true
}
// ParseItem parses an item from a given HTTP Structured Field Values.
//
// Given an HTTP SFV string that represents an item, it will call the given
// function once, with the bare item and the parameter of the item. This allows
// the caller to extract information out of the item.
//
// This function will return once it encounters the end of the string, or
// something that is not an item. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-item.
func ParseItem(s string, f func(bareItem, param string)) (ok bool) {
_, rest, ok := consumeItem(s, f)
return rest == "" && ok
}
// ParseDictionary parses a dictionary from a given HTTP Structured Field
// Values.
//
// Given an HTTP SFV string that represents a dictionary, it will call the
// given function using each of the keys, values, and parameters contained in
// the dictionary. This allows the caller to extract information out of the
// dictionary.
//
// This function will return once it encounters the end of the string, or
// something that is not a dictionary. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-dictionary.
func ParseDictionary(s string, f func(key, val, param string)) (ok bool) {
for len(s) != 0 {
var key, val, param string
val = "?1" // Default value for empty val is boolean true.
if key, s, ok = consumeKey(s); !ok {
return ok
}
if len(s) != 0 && s[0] == '=' {
s = s[1:]
if len(s) != 0 && s[0] == '(' {
if val, s, ok = consumeBareInnerList(s, nil); !ok {
return ok
}
} else {
if val, s, ok = consumeBareItem(s); !ok {
return ok
}
}
}
if param, s, ok = consumeParameter(s, nil); !ok {
return ok
}
if f != nil {
f(key, val, param)
}
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
break
}
if s[0] == ',' {
s = s[1:]
}
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
return false
}
}
return true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#parse-param.
func consumeParameter(s string, f func(key, val string)) (consumed, rest string, ok bool) {
rest = s
for len(rest) != 0 {
var key, val string
val = "?1" // Default value for empty val is boolean true.
if rest[0] != ';' {
break
}
rest = rest[1:]
rest = rest[countLeftWhitespace(rest):]
key, rest, ok = consumeKey(rest)
if !ok {
return "", s, ok
}
if len(rest) != 0 && rest[0] == '=' {
rest = rest[1:]
val, rest, ok = consumeBareItem(rest)
if !ok {
return "", s, ok
}
}
if f != nil {
f(key, val)
}
}
return s[:len(s)-len(rest)], rest, true
}
// ParseParameter parses a parameter from a given HTTP Structured Field Values.
//
// Given an HTTP SFV string that represents a parameter, it will call the given
// function using each of the keys and values contained in the parameter. This
// allows the caller to extract information out of the parameter.
//
// This function will return once it encounters the end of the string, or
// something that is not a parameter. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#parse-param.
func ParseParameter(s string, f func(key, val string)) (ok bool) {
_, rest, ok := consumeParameter(s, f)
return rest == "" && ok
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-key.
func consumeKey(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || (!isLCAlpha(s[0]) && s[0] != '*') {
return "", s, false
}
i := 0
for _, ch := range []byte(s) {
if !isLCAlpha(ch) && !isDigit(ch) && !slices.Contains([]byte("_-.*"), ch) {
break
}
i++
}
return s[:i], s[i:], true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim.
func consumeIntegerOrDecimal(s string) (consumed, rest string, ok bool) {
var i, signOffset, periodIndex int
var isDecimal bool
if i < len(s) && s[i] == '-' {
i++
signOffset++
}
if i >= len(s) {
return "", s, false
}
if !isDigit(s[i]) {
return "", s, false
}
for i < len(s) {
ch := s[i]
if isDigit(ch) {
i++
continue
}
if !isDecimal && ch == '.' {
if i-signOffset > 12 {
return "", s, false
}
periodIndex = i
isDecimal = true
i++
continue
}
break
}
if !isDecimal && i-signOffset > 15 {
return "", s, false
}
if isDecimal {
if i-signOffset > 16 {
return "", s, false
}
if s[i-1] == '.' {
return "", s, false
}
if i-periodIndex-1 > 3 {
return "", s, false
}
}
return s[:i], s[i:], true
}
// ParseInteger parses an integer from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid integer. It returns the
// parsed integer and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim.
func ParseInteger(s string) (parsed int64, ok bool) {
if _, rest, ok := consumeIntegerOrDecimal(s); !ok || rest != "" {
return 0, false
}
if n, err := strconv.ParseInt(s, 10, 64); err == nil {
return n, true
}
return 0, false
}
// ParseDecimal parses a decimal from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid decimal. It returns the
// parsed decimal and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim.
func ParseDecimal(s string) (parsed float64, ok bool) {
if _, rest, ok := consumeIntegerOrDecimal(s); !ok || rest != "" {
return 0, false
}
if !strings.Contains(s, ".") {
return 0, false
}
if n, err := strconv.ParseFloat(s, 64); err == nil {
return n, true
}
return 0, false
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-string.
func consumeString(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != '"' {
return "", s, false
}
for i := 1; i < len(s); i++ {
switch ch := s[i]; ch {
case '\\':
if i+1 >= len(s) {
return "", s, false
}
i++
if ch = s[i]; ch != '"' && ch != '\\' {
return "", s, false
}
case '"':
return s[:i+1], s[i+1:], true
default:
if !isVChar(ch) && !isSP(ch) {
return "", s, false
}
}
}
return "", s, false
}
// ParseString parses a Go string from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid string. It returns the
// parsed string and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-string.
func ParseString(s string) (parsed string, ok bool) {
if _, rest, ok := consumeString(s); !ok || rest != "" {
return "", false
}
return s[1 : len(s)-1], true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-token
func consumeToken(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || (!isAlpha(s[0]) && s[0] != '*') {
return "", s, false
}
i := 0
for _, ch := range []byte(s) {
if !isTChar(ch) && !slices.Contains([]byte(":/"), ch) {
break
}
i++
}
return s[:i], s[i:], true
}
// ParseToken parses a token from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid token. It returns the
// parsed token and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-token
func ParseToken(s string) (parsed string, ok bool) {
if _, rest, ok := consumeToken(s); !ok || rest != "" {
return "", false
}
return s, true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-byte-sequence.
func consumeByteSequence(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != ':' {
return "", s, false
}
for i := 1; i < len(s); i++ {
if ch := s[i]; ch == ':' {
return s[:i+1], s[i+1:], true
}
if ch := s[i]; !isAlpha(ch) && !isDigit(ch) && !slices.Contains([]byte("+/="), ch) {
return "", s, false
}
}
return "", s, false
}
// ParseByteSequence parses a byte sequence from a given HTTP Structured Field
// Values.
//
// The entire HTTP SFV string must consist of a valid byte sequence. It returns
// the parsed byte sequence and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-byte-sequence.
func ParseByteSequence(s string) (parsed []byte, ok bool) {
if _, rest, ok := consumeByteSequence(s); !ok || rest != "" {
return nil, false
}
return []byte(s[1 : len(s)-1]), true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-boolean.
func consumeBoolean(s string) (consumed, rest string, ok bool) {
if len(s) >= 2 && (s[:2] == "?0" || s[:2] == "?1") {
return s[:2], s[2:], true
}
return "", s, false
}
// ParseBoolean parses a boolean from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid boolean. It returns the
// parsed boolean and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-boolean.
func ParseBoolean(s string) (parsed bool, ok bool) {
if _, rest, ok := consumeBoolean(s); !ok || rest != "" {
return false, false
}
return s == "?1", true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-date.
func consumeDate(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != '@' {
return "", s, false
}
if _, rest, ok = consumeIntegerOrDecimal(s[1:]); !ok {
return "", s, ok
}
consumed = s[:len(s)-len(rest)]
if slices.Contains([]byte(consumed), '.') {
return "", s, false
}
return consumed, rest, ok
}
// ParseDate parses a date from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid date. It returns the
// parsed date and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-date.
func ParseDate(s string) (parsed time.Time, ok bool) {
if _, rest, ok := consumeDate(s); !ok || rest != "" {
return time.Time{}, false
}
if n, ok := ParseInteger(s[1:]); !ok {
return time.Time{}, false
} else {
return time.Unix(n, 0), true
}
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-display-string.
func consumeDisplayString(s string) (consumed, rest string, ok bool) {
// To prevent excessive allocation, especially when input is large, we
// maintain a buffer of 4 bytes to keep track of the last rune we
// encounter. This way, we can validate that the display string conforms to
// UTF-8 without actually building the whole string.
var lastRune [4]byte
var runeLen int
isPartOfValidRune := func(ch byte) bool {
lastRune[runeLen] = ch
runeLen++
if utf8.FullRune(lastRune[:runeLen]) {
r, s := utf8.DecodeRune(lastRune[:runeLen])
if r == utf8.RuneError {
return false
}
copy(lastRune[:], lastRune[s:runeLen])
runeLen -= s
return true
}
return runeLen <= 4
}
if len(s) <= 1 || s[:2] != `%"` {
return "", s, false
}
i := 2
for i < len(s) {
ch := s[i]
if !isVChar(ch) && !isSP(ch) {
return "", s, false
}
switch ch {
case '"':
if runeLen > 0 {
return "", s, false
}
return s[:i+1], s[i+1:], true
case '%':
if i+2 >= len(s) {
return "", s, false
}
if ch, ok = decOctetHex(s[i+1], s[i+2]); !ok {
return "", s, ok
}
if ok = isPartOfValidRune(ch); !ok {
return "", s, ok
}
i += 3
default:
if ok = isPartOfValidRune(ch); !ok {
return "", s, ok
}
i++
}
}
return "", s, false
}
// ParseDisplayString parses a display string from a given HTTP Structured
// Field Values.
//
// The entire HTTP SFV string must consist of a valid display string. It
// returns the parsed display string and an ok boolean value, indicating
// success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-display-string.
func ParseDisplayString(s string) (parsed string, ok bool) {
if _, rest, ok := consumeDisplayString(s); !ok || rest != "" {
return "", false
}
// consumeDisplayString() already validates that we have a valid display
// string. Therefore, we can just construct the display string, without
// validating it again.
s = s[2 : len(s)-1]
var b strings.Builder
for i := 0; i < len(s); {
if s[i] == '%' {
decoded, _ := decOctetHex(s[i+1], s[i+2])
b.WriteByte(decoded)
i += 3
continue
}
b.WriteByte(s[i])
i++
}
return b.String(), true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#parse-bare-item.
func consumeBareItem(s string) (consumed, rest string, ok bool) {
if len(s) == 0 {
return "", s, false
}
ch := s[0]
switch {
case ch == '-' || isDigit(ch):
return consumeIntegerOrDecimal(s)
case ch == '"':
return consumeString(s)
case ch == '*' || isAlpha(ch):
return consumeToken(s)
case ch == ':':
return consumeByteSequence(s)
case ch == '?':
return consumeBoolean(s)
case ch == '@':
return consumeDate(s)
case ch == '%':
return consumeDisplayString(s)
default:
return "", s, false
}
}

View file

@ -6,7 +6,10 @@
package socket package socket
import "unsafe" import (
"encoding/binary"
"unsafe"
)
func (h *msghdr) pack(vs []iovec, bs [][]byte, oob []byte, sa []byte) { func (h *msghdr) pack(vs []iovec, bs [][]byte, oob []byte, sa []byte) {
for i := range vs { for i := range vs {
@ -31,5 +34,5 @@ func (h *msghdr) controllen() int {
} }
func (h *msghdr) flags() int { func (h *msghdr) flags() int {
return int(NativeEndian.Uint32(h.Pad_cgo_2[:])) return int(binary.NativeEndian.Uint32(h.Pad_cgo_2[:]))
} }

View file

@ -7,6 +7,7 @@
package socket // import "golang.org/x/net/internal/socket" package socket // import "golang.org/x/net/internal/socket"
import ( import (
"encoding/binary"
"errors" "errors"
"net" "net"
"runtime" "runtime"
@ -58,7 +59,7 @@ func (o *Option) GetInt(c *Conn) (int, error) {
if o.Len == 1 { if o.Len == 1 {
return int(b[0]), nil return int(b[0]), nil
} }
return int(NativeEndian.Uint32(b[:4])), nil return int(binary.NativeEndian.Uint32(b[:4])), nil
} }
// Set writes the option and value to the kernel. // Set writes the option and value to the kernel.
@ -84,7 +85,7 @@ func (o *Option) SetInt(c *Conn, v int) error {
b = []byte{byte(v)} b = []byte{byte(v)}
} else { } else {
var bb [4]byte var bb [4]byte
NativeEndian.PutUint32(bb[:o.Len], uint32(v)) binary.NativeEndian.PutUint32(bb[:o.Len], uint32(v))
b = bb[:4] b = bb[:4]
} }
return o.set(c, b) return o.set(c, b)

View file

@ -1,23 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package socket
import (
"encoding/binary"
"unsafe"
)
// NativeEndian is the machine native endian implementation of ByteOrder.
var NativeEndian binary.ByteOrder
func init() {
i := uint32(1)
b := (*[4]byte)(unsafe.Pointer(&i))
if b[0] == 1 {
NativeEndian = binary.LittleEndian
} else {
NativeEndian = binary.BigEndian
}
}

View file

@ -36,7 +36,7 @@ func marshalSockaddr(ip net.IP, port int, zone string, b []byte) int {
if ip4 := ip.To4(); ip4 != nil { if ip4 := ip.To4(); ip4 != nil {
switch runtime.GOOS { switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows": case "android", "illumos", "linux", "solaris", "windows":
NativeEndian.PutUint16(b[:2], uint16(sysAF_INET)) binary.NativeEndian.PutUint16(b[:2], uint16(sysAF_INET))
default: default:
b[0] = sizeofSockaddrInet4 b[0] = sizeofSockaddrInet4
b[1] = sysAF_INET b[1] = sysAF_INET
@ -48,7 +48,7 @@ func marshalSockaddr(ip net.IP, port int, zone string, b []byte) int {
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil { if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
switch runtime.GOOS { switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows": case "android", "illumos", "linux", "solaris", "windows":
NativeEndian.PutUint16(b[:2], uint16(sysAF_INET6)) binary.NativeEndian.PutUint16(b[:2], uint16(sysAF_INET6))
default: default:
b[0] = sizeofSockaddrInet6 b[0] = sizeofSockaddrInet6
b[1] = sysAF_INET6 b[1] = sysAF_INET6
@ -56,7 +56,7 @@ func marshalSockaddr(ip net.IP, port int, zone string, b []byte) int {
binary.BigEndian.PutUint16(b[2:4], uint16(port)) binary.BigEndian.PutUint16(b[2:4], uint16(port))
copy(b[8:24], ip6) copy(b[8:24], ip6)
if zone != "" { if zone != "" {
NativeEndian.PutUint32(b[24:28], uint32(zoneCache.index(zone))) binary.NativeEndian.PutUint32(b[24:28], uint32(zoneCache.index(zone)))
} }
return sizeofSockaddrInet6 return sizeofSockaddrInet6
} }
@ -70,7 +70,7 @@ func parseInetAddr(b []byte, network string) (net.Addr, error) {
var af int var af int
switch runtime.GOOS { switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows": case "android", "illumos", "linux", "solaris", "windows":
af = int(NativeEndian.Uint16(b[:2])) af = int(binary.NativeEndian.Uint16(b[:2]))
default: default:
af = int(b[1]) af = int(b[1])
} }
@ -89,7 +89,7 @@ func parseInetAddr(b []byte, network string) (net.Addr, error) {
} }
ip = make(net.IP, net.IPv6len) ip = make(net.IP, net.IPv6len)
copy(ip, b[8:24]) copy(ip, b[8:24])
if id := int(NativeEndian.Uint32(b[24:28])); id > 0 { if id := int(binary.NativeEndian.Uint32(b[24:28])); id > 0 {
zone = zoneCache.name(id) zone = zoneCache.name(id)
} }
} }

View file

@ -4,27 +4,27 @@
package socket package socket
type iovec struct { type iovec struct {
Base *byte Base *byte
Len uint64 Len uint64
} }
type msghdr struct { type msghdr struct {
Name *byte Name *byte
Namelen uint32 Namelen uint32
Iov *iovec Iov *iovec
Iovlen uint32 Iovlen uint32
Control *byte Control *byte
Controllen uint32 Controllen uint32
Flags int32 Flags int32
} }
type cmsghdr struct { type cmsghdr struct {
Len uint32 Len uint32
Level int32 Level int32
Type int32 Type int32
} }
const ( const (
sizeofIovec = 0x10 sizeofIovec = 0x10
sizeofMsghdr = 0x30 sizeofMsghdr = 0x30
) )

View file

@ -4,27 +4,27 @@
package socket package socket
type iovec struct { type iovec struct {
Base *byte Base *byte
Len uint64 Len uint64
} }
type msghdr struct { type msghdr struct {
Name *byte Name *byte
Namelen uint32 Namelen uint32
Iov *iovec Iov *iovec
Iovlen uint32 Iovlen uint32
Control *byte Control *byte
Controllen uint32 Controllen uint32
Flags int32 Flags int32
} }
type cmsghdr struct { type cmsghdr struct {
Len uint32 Len uint32
Level int32 Level int32
Type int32 Type int32
} }
const ( const (
sizeofIovec = 0x10 sizeofIovec = 0x10
sizeofMsghdr = 0x30 sizeofMsghdr = 0x30
) )

View file

@ -297,7 +297,7 @@ func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter,
b = append(b, up.Username...) b = append(b, up.Username...)
b = append(b, byte(len(up.Password))) b = append(b, byte(len(up.Password)))
b = append(b, up.Password...) b = append(b, up.Password...)
// TODO(mikio): handle IO deadlines and cancelation if // TODO(mikio): handle IO deadlines and cancellation if
// necessary // necessary
if _, err := rw.Write(b); err != nil { if _, err := rw.Write(b); err != nil {
return err return err

Some files were not shown because too many files have changed in this diff Show more