Merge pull request #740 from fatedier/socks5

frpc: support connectiong frps by socks5 proxy
This commit is contained in:
fatedier 2018-05-04 19:15:11 +08:00 committed by GitHub
commit bebd1db22a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
151 changed files with 23619 additions and 14851 deletions

View File

@ -185,7 +185,7 @@ func (ctl *Control) login() (err error) {
ctl.session.Close() ctl.session.Close()
} }
conn, err := frpNet.ConnectServerByHttpProxy(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol, conn, err := frpNet.ConnectServerByProxy(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol,
fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort)) fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort))
if err != nil { if err != nil {
return err return err
@ -253,7 +253,7 @@ func (ctl *Control) connectServer() (conn frpNet.Conn, err error) {
} }
conn = frpNet.WrapConn(stream) conn = frpNet.WrapConn(stream)
} else { } else {
conn, err = frpNet.ConnectServerByHttpProxy(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol, conn, err = frpNet.ConnectServerByProxy(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol,
fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort)) fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort))
if err != nil { if err != nil {
ctl.Warn("start new connection to server error: %v", err) ctl.Warn("start new connection to server error: %v", err)

View File

@ -5,9 +5,10 @@
server_addr = 0.0.0.0 server_addr = 0.0.0.0
server_port = 7000 server_port = 7000
# if you want to connect frps by http proxy, you can set http_proxy here or in global environment variables # if you want to connect frps by http proxy or socks5 proxy, you can set http_proxy here or in global environment variables
# it only works when protocol is tcp # it only works when protocol is tcp
# http_proxy = http://user:passwd@192.168.1.128:8080 # http_proxy = http://user:passwd@192.168.1.128:8080
# http_proxy = socks5://user:passwd@192.168.1.128:1080
# console or real logFile path like ./frpc.log # console or real logFile path like ./frpc.log
log_file = ./frpc.log log_file = ./frpc.log

8
glide.lock generated
View File

@ -1,5 +1,5 @@
hash: e2a62cbc49d9da8ff95682f5c0b7731a7047afdd139acddb691c51ea98f726e1 hash: 47d70fb6b7dee9b0e453269a7079b42488dae6c4902b4d5c93976a8b7e15f604
updated: 2018-04-25T02:41:38.15698+08:00 updated: 2018-05-04T17:59:35.698911+08:00
imports: imports:
- name: github.com/armon/go-socks5 - name: github.com/armon/go-socks5
version: e75332964ef517daa070d7c38a9466a0d687e0a5 version: e75332964ef517daa070d7c38a9466a0d687e0a5
@ -71,11 +71,13 @@ imports:
- twofish - twofish
- xtea - xtea
- name: golang.org/x/net - name: golang.org/x/net
version: e4fa1c5465ad6111f206fc92186b8c83d64adbe1 version: 640f4622ab692b87c2f3a94265e6f579fe38263d
subpackages: subpackages:
- bpf - bpf
- context - context
- internal/iana - internal/iana
- internal/socket - internal/socket
- internal/socks
- ipv4 - ipv4
- proxy
testImports: [] testImports: []

View File

@ -58,7 +58,7 @@ import:
- twofish - twofish
- xtea - xtea
- package: golang.org/x/net - package: golang.org/x/net
version: e4fa1c5465ad6111f206fc92186b8c83d64adbe1 version: 640f4622ab692b87c2f3a94265e6f579fe38263d
subpackages: subpackages:
- bpf - bpf
- context - context

View File

@ -279,7 +279,7 @@ func TestPluginHttpProxy(t *testing.T) {
} }
// connect method // connect method
conn, err := net.ConnectTcpServerByHttpProxy("http://"+addr, fmt.Sprintf("127.0.0.1:%d", TEST_TCP_FRP_PORT)) conn, err := net.ConnectTcpServerByProxy("http://"+addr, fmt.Sprintf("127.0.0.1:%d", TEST_TCP_FRP_PORT))
if assert.NoError(err) { if assert.NoError(err) {
res, err := sendTcpMsgByConn(conn, TEST_TCP_ECHO_STR) res, err := sendTcpMsgByConn(conn, TEST_TCP_ECHO_STR)
assert.NoError(err) assert.NoError(err)

View File

@ -122,10 +122,10 @@ func ConnectServer(protocol string, addr string) (c Conn, err error) {
} }
} }
func ConnectServerByHttpProxy(httpProxy string, protocol string, addr string) (c Conn, err error) { func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn, err error) {
switch protocol { switch protocol {
case "tcp": case "tcp":
return ConnectTcpServerByHttpProxy(httpProxy, addr) return ConnectTcpServerByProxy(proxyUrl, addr)
case "kcp": case "kcp":
// http proxy is not supported for kcp // http proxy is not supported for kcp
return ConnectServer(protocol, addr) return ConnectServer(protocol, addr)

View File

@ -23,6 +23,8 @@ import (
"net/url" "net/url"
"github.com/fatedier/frp/utils/log" "github.com/fatedier/frp/utils/log"
"golang.org/x/net/proxy"
) )
type TcpListener struct { type TcpListener struct {
@ -93,7 +95,7 @@ type TcpConn struct {
log.Logger log.Logger
} }
func NewTcpConn(conn *net.TCPConn) (c *TcpConn) { func NewTcpConn(conn net.Conn) (c *TcpConn) {
c = &TcpConn{ c = &TcpConn{
Conn: conn, Conn: conn,
Logger: log.NewPrefixLogger(""), Logger: log.NewPrefixLogger(""),
@ -114,28 +116,41 @@ func ConnectTcpServer(addr string) (c Conn, err error) {
return return
} }
// ConnectTcpServerByHttpProxy try to connect remote server by http proxy. // ConnectTcpServerByProxy try to connect remote server by proxy.
// If httpProxy is empty, it will connect server directly. func ConnectTcpServerByProxy(proxyStr string, serverAddr string) (c Conn, err error) {
func ConnectTcpServerByHttpProxy(httpProxy string, serverAddr string) (c Conn, err error) { if proxyStr == "" {
if httpProxy == "" {
return ConnectTcpServer(serverAddr) return ConnectTcpServer(serverAddr)
} }
var proxyUrl *url.URL var (
if proxyUrl, err = url.Parse(httpProxy); err != nil { proxyUrl *url.URL
username string
passwd string
)
if proxyUrl, err = url.Parse(proxyStr); err != nil {
return return
} }
if proxyUrl.User != nil {
username = proxyUrl.User.Username()
passwd, _ = proxyUrl.User.Password()
}
switch proxyUrl.Scheme {
case "http":
return ConnectTcpServerByHttpProxy(proxyUrl, username, passwd, serverAddr)
case "socks5":
return ConnectTcpServerBySocks5Proxy(proxyUrl, username, passwd, serverAddr)
default:
err = fmt.Errorf("Proxy URL scheme must be http or socks5, not [%s]", proxyUrl.Scheme)
return
}
}
// ConnectTcpServerByHttpProxy try to connect remote server by http proxy.
func ConnectTcpServerByHttpProxy(proxyUrl *url.URL, user string, passwd string, serverAddr string) (c Conn, err error) {
var proxyAuth string var proxyAuth string
if proxyUrl.User != nil { if proxyUrl.User != nil {
username := proxyUrl.User.Username() proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+passwd))
passwd, _ := proxyUrl.User.Password()
proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+passwd))
}
if proxyUrl.Scheme != "http" {
err = fmt.Errorf("Proxy URL scheme must be http, not [%s]", proxyUrl.Scheme)
return
} }
if c, err = ConnectTcpServer(proxyUrl.Host); err != nil { if c, err = ConnectTcpServer(proxyUrl.Host); err != nil {
@ -161,6 +176,27 @@ func ConnectTcpServerByHttpProxy(httpProxy string, serverAddr string) (c Conn, e
err = fmt.Errorf("ConnectTcpServer using proxy error, StatusCode [%d]", resp.StatusCode) err = fmt.Errorf("ConnectTcpServer using proxy error, StatusCode [%d]", resp.StatusCode)
return return
} }
return
}
func ConnectTcpServerBySocks5Proxy(proxyUrl *url.URL, user string, passwd string, serverAddr string) (c Conn, err error) {
var auth *proxy.Auth
if proxyUrl.User != nil {
auth = &proxy.Auth{
User: user,
Password: passwd,
}
}
dialer, err := proxy.SOCKS5("tcp", proxyUrl.Host, auth, nil)
if err != nil {
return nil, err
}
var conn net.Conn
if conn, err = dialer.Dial("tcp", serverAddr); err != nil {
return
}
c = NewTcpConn(conn)
return return
} }

View File

@ -4,16 +4,15 @@ Go is an open source project.
It is the work of hundreds of contributors. We appreciate your help! It is the work of hundreds of contributors. We appreciate your help!
## Filing issues ## Filing issues
When [filing an issue](https://golang.org/issue/new), make sure to answer these five questions: When [filing an issue](https://golang.org/issue/new), make sure to answer these five questions:
1. What version of Go are you using (`go version`)? 1. What version of Go are you using (`go version`)?
2. What operating system and processor architecture are you using? 2. What operating system and processor architecture are you using?
3. What did you do? 3. What did you do?
4. What did you expect to see? 4. What did you expect to see?
5. What did you see instead? 5. What did you see instead?
General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker. General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker.
The gophers there will answer or ask you to file an issue if you've tripped over a bug. The gophers there will answer or ask you to file an issue if you've tripped over a bug.
@ -23,9 +22,5 @@ The gophers there will answer or ask you to file an issue if you've tripped over
Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html) Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html)
before sending patches. before sending patches.
**We do not accept GitHub pull requests**
(we use [Gerrit](https://code.google.com/p/gerrit/) instead for code review).
Unless otherwise noted, the Go source files are distributed under Unless otherwise noted, the Go source files are distributed under
the BSD-style license found in the LICENSE file. the BSD-style license found in the LICENSE file.

3
vendor/golang.org/x/net/README generated vendored
View File

@ -1,3 +0,0 @@
This repository holds supplementary Go networking libraries.
To submit changes to this repository, see http://golang.org/doc/contribute.html.

16
vendor/golang.org/x/net/README.md generated vendored Normal file
View File

@ -0,0 +1,16 @@
# Go Networking
This repository holds supplementary Go networking libraries.
## Download/Install
The easiest way to install is to run `go get -u golang.org/x/net`. You can
also manually git clone the repository to `$GOPATH/src/golang.org/x/net`.
## Report Issues / Send Patches
This repository uses Gerrit for code changes. To learn how to submit
changes to this repository, see https://golang.org/doc/contribute.html.
The main issue tracker for the net repository is located at
https://github.com/golang/go/issues. Prefix your issue with "x/net:" in the
subject line, so it is easy to find.

View File

@ -198,7 +198,7 @@ func (a LoadConstant) Assemble() (RawInstruction, error) {
return assembleLoad(a.Dst, 4, opAddrModeImmediate, a.Val) return assembleLoad(a.Dst, 4, opAddrModeImmediate, a.Val)
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a LoadConstant) String() string { func (a LoadConstant) String() string {
switch a.Dst { switch a.Dst {
case RegA: case RegA:
@ -224,7 +224,7 @@ func (a LoadScratch) Assemble() (RawInstruction, error) {
return assembleLoad(a.Dst, 4, opAddrModeScratch, uint32(a.N)) return assembleLoad(a.Dst, 4, opAddrModeScratch, uint32(a.N))
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a LoadScratch) String() string { func (a LoadScratch) String() string {
switch a.Dst { switch a.Dst {
case RegA: case RegA:
@ -248,7 +248,7 @@ func (a LoadAbsolute) Assemble() (RawInstruction, error) {
return assembleLoad(RegA, a.Size, opAddrModeAbsolute, a.Off) return assembleLoad(RegA, a.Size, opAddrModeAbsolute, a.Off)
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a LoadAbsolute) String() string { func (a LoadAbsolute) String() string {
switch a.Size { switch a.Size {
case 1: // byte case 1: // byte
@ -277,7 +277,7 @@ func (a LoadIndirect) Assemble() (RawInstruction, error) {
return assembleLoad(RegA, a.Size, opAddrModeIndirect, a.Off) return assembleLoad(RegA, a.Size, opAddrModeIndirect, a.Off)
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a LoadIndirect) String() string { func (a LoadIndirect) String() string {
switch a.Size { switch a.Size {
case 1: // byte case 1: // byte
@ -306,7 +306,7 @@ func (a LoadMemShift) Assemble() (RawInstruction, error) {
return assembleLoad(RegX, 1, opAddrModeMemShift, a.Off) return assembleLoad(RegX, 1, opAddrModeMemShift, a.Off)
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a LoadMemShift) String() string { func (a LoadMemShift) String() string {
return fmt.Sprintf("ldx 4*([%d]&0xf)", a.Off) return fmt.Sprintf("ldx 4*([%d]&0xf)", a.Off)
} }
@ -325,7 +325,7 @@ func (a LoadExtension) Assemble() (RawInstruction, error) {
return assembleLoad(RegA, 4, opAddrModeAbsolute, uint32(extOffset+a.Num)) return assembleLoad(RegA, 4, opAddrModeAbsolute, uint32(extOffset+a.Num))
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a LoadExtension) String() string { func (a LoadExtension) String() string {
switch a.Num { switch a.Num {
case ExtLen: case ExtLen:
@ -392,7 +392,7 @@ func (a StoreScratch) Assemble() (RawInstruction, error) {
}, nil }, nil
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a StoreScratch) String() string { func (a StoreScratch) String() string {
switch a.Src { switch a.Src {
case RegA: case RegA:
@ -418,7 +418,7 @@ func (a ALUOpConstant) Assemble() (RawInstruction, error) {
}, nil }, nil
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a ALUOpConstant) String() string { func (a ALUOpConstant) String() string {
switch a.Op { switch a.Op {
case ALUOpAdd: case ALUOpAdd:
@ -458,7 +458,7 @@ func (a ALUOpX) Assemble() (RawInstruction, error) {
}, nil }, nil
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a ALUOpX) String() string { func (a ALUOpX) String() string {
switch a.Op { switch a.Op {
case ALUOpAdd: case ALUOpAdd:
@ -496,7 +496,7 @@ func (a NegateA) Assemble() (RawInstruction, error) {
}, nil }, nil
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a NegateA) String() string { func (a NegateA) String() string {
return fmt.Sprintf("neg") return fmt.Sprintf("neg")
} }
@ -514,7 +514,7 @@ func (a Jump) Assemble() (RawInstruction, error) {
}, nil }, nil
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a Jump) String() string { func (a Jump) String() string {
return fmt.Sprintf("ja %d", a.Skip) return fmt.Sprintf("ja %d", a.Skip)
} }
@ -566,7 +566,7 @@ func (a JumpIf) Assemble() (RawInstruction, error) {
}, nil }, nil
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a JumpIf) String() string { func (a JumpIf) String() string {
switch a.Cond { switch a.Cond {
// K == A // K == A
@ -621,7 +621,7 @@ func (a RetA) Assemble() (RawInstruction, error) {
}, nil }, nil
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a RetA) String() string { func (a RetA) String() string {
return fmt.Sprintf("ret a") return fmt.Sprintf("ret a")
} }
@ -639,7 +639,7 @@ func (a RetConstant) Assemble() (RawInstruction, error) {
}, nil }, nil
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a RetConstant) String() string { func (a RetConstant) String() string {
return fmt.Sprintf("ret #%d", a.Val) return fmt.Sprintf("ret #%d", a.Val)
} }
@ -654,7 +654,7 @@ func (a TXA) Assemble() (RawInstruction, error) {
}, nil }, nil
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a TXA) String() string { func (a TXA) String() string {
return fmt.Sprintf("txa") return fmt.Sprintf("txa")
} }
@ -669,7 +669,7 @@ func (a TAX) Assemble() (RawInstruction, error) {
}, nil }, nil
} }
// String returns the the instruction in assembler notation. // String returns the instruction in assembler notation.
func (a TAX) String() string { func (a TAX) String() string {
return fmt.Sprintf("tax") return fmt.Sprintf("tax")
} }

10
vendor/golang.org/x/net/bpf/setter.go generated vendored Normal file
View File

@ -0,0 +1,10 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bpf
// A Setter is a type which can attach a compiled BPF filter to itself.
type Setter interface {
SetBPF(filter []RawInstruction) error
}

View File

@ -5,6 +5,8 @@
// Package context defines the Context type, which carries deadlines, // Package context defines the Context type, which carries deadlines,
// cancelation signals, and other request-scoped values across API boundaries // cancelation signals, and other request-scoped values across API boundaries
// and between processes. // and between processes.
// As of Go 1.7 this package is available in the standard library under the
// name context. https://golang.org/pkg/context.
// //
// Incoming requests to a server should create a Context, and outgoing calls to // Incoming requests to a server should create a Context, and outgoing calls to
// servers should accept a Context. The chain of function calls between must // servers should accept a Context. The chain of function calls between must
@ -36,103 +38,6 @@
// Contexts. // Contexts.
package context // import "golang.org/x/net/context" package context // import "golang.org/x/net/context"
import "time"
// A Context carries a deadline, a cancelation signal, and other values across
// API boundaries.
//
// Context's methods may be called by multiple goroutines simultaneously.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
//
// WithCancel arranges for Done to be closed when cancel is called;
// WithDeadline arranges for Done to be closed when the deadline
// expires; WithTimeout arranges for Done to be closed when the timeout
// elapses.
//
// Done is provided for use in select statements:
//
// // Stream generates values with DoSomething and sends them to out
// // until DoSomething returns an error or ctx.Done is closed.
// func Stream(ctx context.Context, out chan<- Value) error {
// for {
// v, err := DoSomething(ctx)
// if err != nil {
// return err
// }
// select {
// case <-ctx.Done():
// return ctx.Err()
// case out <- v:
// }
// }
// }
//
// See http://blog.golang.org/pipelines for more examples of how to use
// a Done channel for cancelation.
Done() <-chan struct{}
// Err returns a non-nil error value after Done is closed. Err returns
// Canceled if the context was canceled or DeadlineExceeded if the
// context's deadline passed. No other values for Err are defined.
// After Done is closed, successive calls to Err return the same value.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
//
// A key identifies a specific value in a Context. Functions that wish
// to store values in Context typically allocate a key in a global
// variable then use that key as the argument to context.WithValue and
// Context.Value. A key can be any type that supports equality;
// packages should define keys as an unexported type to avoid
// collisions.
//
// Packages that define a Context key should provide type-safe accessors
// for the values stores using that key:
//
// // Package user defines a User type that's stored in Contexts.
// package user
//
// import "golang.org/x/net/context"
//
// // User is the type of value stored in the Contexts.
// type User struct {...}
//
// // key is an unexported type for keys defined in this package.
// // This prevents collisions with keys defined in other packages.
// type key int
//
// // userKey is the key for user.User values in Contexts. It is
// // unexported; clients use user.NewContext and user.FromContext
// // instead of using this key directly.
// var userKey key = 0
//
// // NewContext returns a new Context that carries value u.
// func NewContext(ctx context.Context, u *User) context.Context {
// return context.WithValue(ctx, userKey, u)
// }
//
// // FromContext returns the User value stored in ctx, if any.
// func FromContext(ctx context.Context) (*User, bool) {
// u, ok := ctx.Value(userKey).(*User)
// return u, ok
// }
Value(key interface{}) interface{}
}
// Background returns a non-nil, empty Context. It is never canceled, has no // Background returns a non-nil, empty Context. It is never canceled, has no
// values, and has no deadline. It is typically used by the main function, // values, and has no deadline. It is typically used by the main function,
// initialization, and tests, and as the top-level Context for incoming // initialization, and tests, and as the top-level Context for incoming
@ -149,8 +54,3 @@ func Background() Context {
func TODO() Context { func TODO() Context {
return todo return todo
} }
// A CancelFunc tells an operation to abandon its work.
// A CancelFunc does not wait for the work to stop.
// After the first call, subsequent calls to a CancelFunc do nothing.
type CancelFunc func()

20
vendor/golang.org/x/net/context/go19.go generated vendored Normal file
View File

@ -0,0 +1,20 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.9
package context
import "context" // standard library's context, as of Go 1.7
// A Context carries a deadline, a cancelation signal, and other values across
// API boundaries.
//
// Context's methods may be called by multiple goroutines simultaneously.
type Context = context.Context
// A CancelFunc tells an operation to abandon its work.
// A CancelFunc does not wait for the work to stop.
// After the first call, subsequent calls to a CancelFunc do nothing.
type CancelFunc = context.CancelFunc

109
vendor/golang.org/x/net/context/pre_go19.go generated vendored Normal file
View File

@ -0,0 +1,109 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.9
package context
import "time"
// A Context carries a deadline, a cancelation signal, and other values across
// API boundaries.
//
// Context's methods may be called by multiple goroutines simultaneously.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
//
// WithCancel arranges for Done to be closed when cancel is called;
// WithDeadline arranges for Done to be closed when the deadline
// expires; WithTimeout arranges for Done to be closed when the timeout
// elapses.
//
// Done is provided for use in select statements:
//
// // Stream generates values with DoSomething and sends them to out
// // until DoSomething returns an error or ctx.Done is closed.
// func Stream(ctx context.Context, out chan<- Value) error {
// for {
// v, err := DoSomething(ctx)
// if err != nil {
// return err
// }
// select {
// case <-ctx.Done():
// return ctx.Err()
// case out <- v:
// }
// }
// }
//
// See http://blog.golang.org/pipelines for more examples of how to use
// a Done channel for cancelation.
Done() <-chan struct{}
// Err returns a non-nil error value after Done is closed. Err returns
// Canceled if the context was canceled or DeadlineExceeded if the
// context's deadline passed. No other values for Err are defined.
// After Done is closed, successive calls to Err return the same value.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
//
// A key identifies a specific value in a Context. Functions that wish
// to store values in Context typically allocate a key in a global
// variable then use that key as the argument to context.WithValue and
// Context.Value. A key can be any type that supports equality;
// packages should define keys as an unexported type to avoid
// collisions.
//
// Packages that define a Context key should provide type-safe accessors
// for the values stores using that key:
//
// // Package user defines a User type that's stored in Contexts.
// package user
//
// import "golang.org/x/net/context"
//
// // User is the type of value stored in the Contexts.
// type User struct {...}
//
// // key is an unexported type for keys defined in this package.
// // This prevents collisions with keys defined in other packages.
// type key int
//
// // userKey is the key for user.User values in Contexts. It is
// // unexported; clients use user.NewContext and user.FromContext
// // instead of using this key directly.
// var userKey key = 0
//
// // NewContext returns a new Context that carries value u.
// func NewContext(ctx context.Context, u *User) context.Context {
// return context.WithValue(ctx, userKey, u)
// }
//
// // FromContext returns the User value stored in ctx, if any.
// func FromContext(ctx context.Context) (*User, bool) {
// u, ok := ctx.Value(userKey).(*User)
// return u, ok
// }
Value(key interface{}) interface{}
}
// A CancelFunc tells an operation to abandon its work.
// A CancelFunc does not wait for the work to stop.
// After the first call, subsequent calls to a CancelFunc do nothing.
type CancelFunc func()

View File

@ -11,16 +11,21 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
) )
// This example passes a context with a timeout to tell a blocking function that
// it should abandon its work after the timeout elapses.
func ExampleWithTimeout() { func ExampleWithTimeout() {
// Pass a context with a timeout to tell a blocking function that it // Pass a context with a timeout to tell a blocking function that it
// should abandon its work after the timeout elapses. // should abandon its work after the timeout elapses.
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
select { select {
case <-time.After(200 * time.Millisecond): case <-time.After(1 * time.Second):
fmt.Println("overslept") fmt.Println("overslept")
case <-ctx.Done(): case <-ctx.Done():
fmt.Println(ctx.Err()) // prints "context deadline exceeded" fmt.Println(ctx.Err()) // prints "context deadline exceeded"
} }
// Output: // Output:
// context deadline exceeded // context deadline exceeded
} }

132
vendor/golang.org/x/net/dns/dnsmessage/example_test.go generated vendored Normal file
View File

@ -0,0 +1,132 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dnsmessage_test
import (
"fmt"
"net"
"strings"
"golang.org/x/net/dns/dnsmessage"
)
func mustNewName(name string) dnsmessage.Name {
n, err := dnsmessage.NewName(name)
if err != nil {
panic(err)
}
return n
}
func ExampleParser() {
msg := dnsmessage.Message{
Header: dnsmessage.Header{Response: true, Authoritative: true},
Questions: []dnsmessage.Question{
{
Name: mustNewName("foo.bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
{
Name: mustNewName("bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
Answers: []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: mustNewName("foo.bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}},
},
{
Header: dnsmessage.ResourceHeader{
Name: mustNewName("bar.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}},
},
},
}
buf, err := msg.Pack()
if err != nil {
panic(err)
}
wantName := "bar.example.com."
var p dnsmessage.Parser
if _, err := p.Start(buf); err != nil {
panic(err)
}
for {
q, err := p.Question()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
panic(err)
}
if q.Name.String() != wantName {
continue
}
fmt.Println("Found question for name", wantName)
if err := p.SkipAllQuestions(); err != nil {
panic(err)
}
break
}
var gotIPs []net.IP
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
panic(err)
}
if (h.Type != dnsmessage.TypeA && h.Type != dnsmessage.TypeAAAA) || h.Class != dnsmessage.ClassINET {
continue
}
if !strings.EqualFold(h.Name.String(), wantName) {
if err := p.SkipAnswer(); err != nil {
panic(err)
}
continue
}
switch h.Type {
case dnsmessage.TypeA:
r, err := p.AResource()
if err != nil {
panic(err)
}
gotIPs = append(gotIPs, r.A[:])
case dnsmessage.TypeAAAA:
r, err := p.AAAAResource()
if err != nil {
panic(err)
}
gotIPs = append(gotIPs, r.AAAA[:])
}
}
fmt.Printf("Found A/AAAA records for name %s: %v\n", wantName, gotIPs)
// Output:
// Found question for name bar.example.com.
// Found A/AAAA records for name bar.example.com.: [127.0.0.2]
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -4,17 +4,17 @@
// +build ignore // +build ignore
//go:generate go run gen.go
//go:generate go run gen.go -test
package main package main
// This program generates table.go and table_test.go.
// Invoke as
//
// go run gen.go |gofmt >table.go
// go run gen.go -test |gofmt >table_test.go
import ( import (
"bytes"
"flag" "flag"
"fmt" "fmt"
"go/format"
"io/ioutil"
"math/rand" "math/rand"
"os" "os"
"sort" "sort"
@ -42,6 +42,18 @@ func identifier(s string) string {
var test = flag.Bool("test", false, "generate table_test.go") var test = flag.Bool("test", false, "generate table_test.go")
func genFile(name string, buf *bytes.Buffer) {
b, err := format.Source(buf.Bytes())
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
if err := ioutil.WriteFile(name, b, 0644); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func main() { func main() {
flag.Parse() flag.Parse()
@ -52,32 +64,31 @@ func main() {
all = append(all, extra...) all = append(all, extra...)
sort.Strings(all) sort.Strings(all)
if *test {
fmt.Printf("// generated by go run gen.go -test; DO NOT EDIT\n\n")
fmt.Printf("package atom\n\n")
fmt.Printf("var testAtomList = []string{\n")
for _, s := range all {
fmt.Printf("\t%q,\n", s)
}
fmt.Printf("}\n")
return
}
// uniq - lists have dups // uniq - lists have dups
// compute max len too
maxLen := 0
w := 0 w := 0
for _, s := range all { for _, s := range all {
if w == 0 || all[w-1] != s { if w == 0 || all[w-1] != s {
if maxLen < len(s) {
maxLen = len(s)
}
all[w] = s all[w] = s
w++ w++
} }
} }
all = all[:w] all = all[:w]
if *test {
var buf bytes.Buffer
fmt.Fprintln(&buf, "// Code generated by go generate gen.go; DO NOT EDIT.\n")
fmt.Fprintln(&buf, "//go:generate go run gen.go -test\n")
fmt.Fprintln(&buf, "package atom\n")
fmt.Fprintln(&buf, "var testAtomList = []string{")
for _, s := range all {
fmt.Fprintf(&buf, "\t%q,\n", s)
}
fmt.Fprintln(&buf, "}")
genFile("table_test.go", &buf)
return
}
// Find hash that minimizes table size. // Find hash that minimizes table size.
var best *table var best *table
for i := 0; i < 1000000; i++ { for i := 0; i < 1000000; i++ {
@ -163,36 +174,46 @@ func main() {
atom[s] = uint32(off<<8 | len(s)) atom[s] = uint32(off<<8 | len(s))
} }
var buf bytes.Buffer
// Generate the Go code. // Generate the Go code.
fmt.Printf("// generated by go run gen.go; DO NOT EDIT\n\n") fmt.Fprintln(&buf, "// Code generated by go generate gen.go; DO NOT EDIT.\n")
fmt.Printf("package atom\n\nconst (\n") fmt.Fprintln(&buf, "//go:generate go run gen.go\n")
fmt.Fprintln(&buf, "package atom\n\nconst (")
// compute max len
maxLen := 0
for _, s := range all { for _, s := range all {
fmt.Printf("\t%s Atom = %#x\n", identifier(s), atom[s]) if maxLen < len(s) {
maxLen = len(s)
}
fmt.Fprintf(&buf, "\t%s Atom = %#x\n", identifier(s), atom[s])
} }
fmt.Printf(")\n\n") fmt.Fprintln(&buf, ")\n")
fmt.Printf("const hash0 = %#x\n\n", best.h0) fmt.Fprintf(&buf, "const hash0 = %#x\n\n", best.h0)
fmt.Printf("const maxAtomLen = %d\n\n", maxLen) fmt.Fprintf(&buf, "const maxAtomLen = %d\n\n", maxLen)
fmt.Printf("var table = [1<<%d]Atom{\n", best.k) fmt.Fprintf(&buf, "var table = [1<<%d]Atom{\n", best.k)
for i, s := range best.tab { for i, s := range best.tab {
if s == "" { if s == "" {
continue continue
} }
fmt.Printf("\t%#x: %#x, // %s\n", i, atom[s], s) fmt.Fprintf(&buf, "\t%#x: %#x, // %s\n", i, atom[s], s)
} }
fmt.Printf("}\n") fmt.Fprintf(&buf, "}\n")
datasize := (1 << best.k) * 4 datasize := (1 << best.k) * 4
fmt.Printf("const atomText =\n") fmt.Fprintln(&buf, "const atomText =")
textsize := len(text) textsize := len(text)
for len(text) > 60 { for len(text) > 60 {
fmt.Printf("\t%q +\n", text[:60]) fmt.Fprintf(&buf, "\t%q +\n", text[:60])
text = text[60:] text = text[60:]
} }
fmt.Printf("\t%q\n\n", text) fmt.Fprintf(&buf, "\t%q\n\n", text)
fmt.Fprintf(os.Stderr, "%d atoms; %d string bytes + %d tables = %d total data\n", len(all), textsize, datasize, textsize+datasize) genFile("table.go", &buf)
fmt.Fprintf(os.Stdout, "%d atoms; %d string bytes + %d tables = %d total data\n", len(all), textsize, datasize, textsize+datasize)
} }
type byLen []string type byLen []string
@ -285,8 +306,10 @@ func (t *table) push(i uint32, depth int) bool {
// The lists of element names and attribute keys were taken from // The lists of element names and attribute keys were taken from
// https://html.spec.whatwg.org/multipage/indices.html#index // https://html.spec.whatwg.org/multipage/indices.html#index
// as of the "HTML Living Standard - Last Updated 21 February 2015" version. // as of the "HTML Living Standard - Last Updated 16 April 2018" version.
// "command", "keygen" and "menuitem" have been removed from the spec,
// but are kept here for backwards compatibility.
var elements = []string{ var elements = []string{
"a", "a",
"abbr", "abbr",
@ -349,6 +372,7 @@ var elements = []string{
"legend", "legend",
"li", "li",
"link", "link",
"main",
"map", "map",
"mark", "mark",
"menu", "menu",
@ -364,6 +388,7 @@ var elements = []string{
"output", "output",
"p", "p",
"param", "param",
"picture",
"pre", "pre",
"progress", "progress",
"q", "q",
@ -375,6 +400,7 @@ var elements = []string{
"script", "script",
"section", "section",
"select", "select",
"slot",
"small", "small",
"source", "source",
"span", "span",
@ -403,14 +429,21 @@ var elements = []string{
} }
// https://html.spec.whatwg.org/multipage/indices.html#attributes-3 // https://html.spec.whatwg.org/multipage/indices.html#attributes-3
//
// "challenge", "command", "contextmenu", "dropzone", "icon", "keytype", "mediagroup",
// "radiogroup", "spellcheck", "scoped", "seamless", "sortable" and "sorted" have been removed from the spec,
// but are kept here for backwards compatibility.
var attributes = []string{ var attributes = []string{
"abbr", "abbr",
"accept", "accept",
"accept-charset", "accept-charset",
"accesskey", "accesskey",
"action", "action",
"allowfullscreen",
"allowpaymentrequest",
"allowusermedia",
"alt", "alt",
"as",
"async", "async",
"autocomplete", "autocomplete",
"autofocus", "autofocus",
@ -420,6 +453,7 @@ var attributes = []string{
"checked", "checked",
"cite", "cite",
"class", "class",
"color",
"cols", "cols",
"colspan", "colspan",
"command", "command",
@ -457,6 +491,8 @@ var attributes = []string{
"icon", "icon",
"id", "id",
"inputmode", "inputmode",
"integrity",
"is",
"ismap", "ismap",
"itemid", "itemid",
"itemprop", "itemprop",
@ -481,16 +517,20 @@ var attributes = []string{
"multiple", "multiple",
"muted", "muted",
"name", "name",
"nomodule",
"nonce",
"novalidate", "novalidate",
"open", "open",
"optimum", "optimum",
"pattern", "pattern",
"ping", "ping",
"placeholder", "placeholder",
"playsinline",
"poster", "poster",
"preload", "preload",
"radiogroup", "radiogroup",
"readonly", "readonly",
"referrerpolicy",
"rel", "rel",
"required", "required",
"reversed", "reversed",
@ -507,10 +547,13 @@ var attributes = []string{
"sizes", "sizes",
"sortable", "sortable",
"sorted", "sorted",
"slot",
"span", "span",
"spellcheck",
"src", "src",
"srcdoc", "srcdoc",
"srclang", "srclang",
"srcset",
"start", "start",
"step", "step",
"style", "style",
@ -520,16 +563,22 @@ var attributes = []string{
"translate", "translate",
"type", "type",
"typemustmatch", "typemustmatch",
"updateviacache",
"usemap", "usemap",
"value", "value",
"width", "width",
"workertype",
"wrap", "wrap",
} }
// "onautocomplete", "onautocompleteerror", "onmousewheel",
// "onshow" and "onsort" have been removed from the spec,
// but are kept here for backwards compatibility.
var eventHandlers = []string{ var eventHandlers = []string{
"onabort", "onabort",
"onautocomplete", "onautocomplete",
"onautocompleteerror", "onautocompleteerror",
"onauxclick",
"onafterprint", "onafterprint",
"onbeforeprint", "onbeforeprint",
"onbeforeunload", "onbeforeunload",
@ -541,11 +590,14 @@ var eventHandlers = []string{
"onclick", "onclick",
"onclose", "onclose",
"oncontextmenu", "oncontextmenu",
"oncopy",
"oncuechange", "oncuechange",
"oncut",
"ondblclick", "ondblclick",
"ondrag", "ondrag",
"ondragend", "ondragend",
"ondragenter", "ondragenter",
"ondragexit",
"ondragleave", "ondragleave",
"ondragover", "ondragover",
"ondragstart", "ondragstart",
@ -565,18 +617,24 @@ var eventHandlers = []string{
"onload", "onload",
"onloadeddata", "onloadeddata",
"onloadedmetadata", "onloadedmetadata",
"onloadend",
"onloadstart", "onloadstart",
"onmessage", "onmessage",
"onmessageerror",
"onmousedown", "onmousedown",
"onmouseenter",
"onmouseleave",
"onmousemove", "onmousemove",
"onmouseout", "onmouseout",
"onmouseover", "onmouseover",
"onmouseup", "onmouseup",
"onmousewheel", "onmousewheel",
"onwheel",
"onoffline", "onoffline",
"ononline", "ononline",
"onpagehide", "onpagehide",
"onpageshow", "onpageshow",
"onpaste",
"onpause", "onpause",
"onplay", "onplay",
"onplaying", "onplaying",
@ -585,7 +643,9 @@ var eventHandlers = []string{
"onratechange", "onratechange",
"onreset", "onreset",
"onresize", "onresize",
"onrejectionhandled",
"onscroll", "onscroll",
"onsecuritypolicyviolation",
"onseeked", "onseeked",
"onseeking", "onseeking",
"onselect", "onselect",
@ -597,6 +657,7 @@ var eventHandlers = []string{
"onsuspend", "onsuspend",
"ontimeupdate", "ontimeupdate",
"ontoggle", "ontoggle",
"onunhandledrejection",
"onunload", "onunload",
"onvolumechange", "onvolumechange",
"onwaiting", "onwaiting",
@ -604,6 +665,7 @@ var eventHandlers = []string{
// extra are ad-hoc values not covered by any of the lists above. // extra are ad-hoc values not covered by any of the lists above.
var extra = []string{ var extra = []string{
"acronym",
"align", "align",
"annotation", "annotation",
"annotation-xml", "annotation-xml",
@ -639,6 +701,8 @@ var extra = []string{
"plaintext", "plaintext",
"prompt", "prompt",
"public", "public",
"rb",
"rtc",
"spacer", "spacer",
"strike", "strike",
"svg", "svg",

File diff suppressed because it is too large Load Diff

View File

@ -1,23 +1,29 @@
// generated by go run gen.go -test; DO NOT EDIT // Code generated by go generate gen.go; DO NOT EDIT.
//go:generate go run gen.go -test
package atom package atom
var testAtomList = []string{ var testAtomList = []string{
"a", "a",
"abbr", "abbr",
"abbr",
"accept", "accept",
"accept-charset", "accept-charset",
"accesskey", "accesskey",
"acronym",
"action", "action",
"address", "address",
"align", "align",
"allowfullscreen",
"allowpaymentrequest",
"allowusermedia",
"alt", "alt",
"annotation", "annotation",
"annotation-xml", "annotation-xml",
"applet", "applet",
"area", "area",
"article", "article",
"as",
"aside", "aside",
"async", "async",
"audio", "audio",
@ -43,7 +49,6 @@ var testAtomList = []string{
"charset", "charset",
"checked", "checked",
"cite", "cite",
"cite",
"class", "class",
"code", "code",
"col", "col",
@ -52,7 +57,6 @@ var testAtomList = []string{
"cols", "cols",
"colspan", "colspan",
"command", "command",
"command",
"content", "content",
"contenteditable", "contenteditable",
"contextmenu", "contextmenu",
@ -60,7 +64,6 @@ var testAtomList = []string{
"coords", "coords",
"crossorigin", "crossorigin",
"data", "data",
"data",
"datalist", "datalist",
"datetime", "datetime",
"dd", "dd",
@ -93,7 +96,6 @@ var testAtomList = []string{
"foreignObject", "foreignObject",
"foreignobject", "foreignobject",
"form", "form",
"form",
"formaction", "formaction",
"formenctype", "formenctype",
"formmethod", "formmethod",
@ -128,6 +130,8 @@ var testAtomList = []string{
"input", "input",
"inputmode", "inputmode",
"ins", "ins",
"integrity",
"is",
"isindex", "isindex",
"ismap", "ismap",
"itemid", "itemid",
@ -140,7 +144,6 @@ var testAtomList = []string{
"keytype", "keytype",
"kind", "kind",
"label", "label",
"label",
"lang", "lang",
"legend", "legend",
"li", "li",
@ -149,6 +152,7 @@ var testAtomList = []string{
"listing", "listing",
"loop", "loop",
"low", "low",
"main",
"malignmark", "malignmark",
"manifest", "manifest",
"map", "map",
@ -179,6 +183,8 @@ var testAtomList = []string{
"nobr", "nobr",
"noembed", "noembed",
"noframes", "noframes",
"nomodule",
"nonce",
"noscript", "noscript",
"novalidate", "novalidate",
"object", "object",
@ -187,6 +193,7 @@ var testAtomList = []string{
"onafterprint", "onafterprint",
"onautocomplete", "onautocomplete",
"onautocompleteerror", "onautocompleteerror",
"onauxclick",
"onbeforeprint", "onbeforeprint",
"onbeforeunload", "onbeforeunload",
"onblur", "onblur",
@ -197,11 +204,14 @@ var testAtomList = []string{
"onclick", "onclick",
"onclose", "onclose",
"oncontextmenu", "oncontextmenu",
"oncopy",
"oncuechange", "oncuechange",
"oncut",
"ondblclick", "ondblclick",
"ondrag", "ondrag",
"ondragend", "ondragend",
"ondragenter", "ondragenter",
"ondragexit",
"ondragleave", "ondragleave",
"ondragover", "ondragover",
"ondragstart", "ondragstart",
@ -221,9 +231,13 @@ var testAtomList = []string{
"onload", "onload",
"onloadeddata", "onloadeddata",
"onloadedmetadata", "onloadedmetadata",
"onloadend",
"onloadstart", "onloadstart",
"onmessage", "onmessage",
"onmessageerror",
"onmousedown", "onmousedown",
"onmouseenter",
"onmouseleave",
"onmousemove", "onmousemove",
"onmouseout", "onmouseout",
"onmouseover", "onmouseover",
@ -233,15 +247,18 @@ var testAtomList = []string{
"ononline", "ononline",
"onpagehide", "onpagehide",
"onpageshow", "onpageshow",
"onpaste",
"onpause", "onpause",
"onplay", "onplay",
"onplaying", "onplaying",
"onpopstate", "onpopstate",
"onprogress", "onprogress",
"onratechange", "onratechange",
"onrejectionhandled",
"onreset", "onreset",
"onresize", "onresize",
"onscroll", "onscroll",
"onsecuritypolicyviolation",
"onseeked", "onseeked",
"onseeking", "onseeking",
"onselect", "onselect",
@ -253,9 +270,11 @@ var testAtomList = []string{
"onsuspend", "onsuspend",
"ontimeupdate", "ontimeupdate",
"ontoggle", "ontoggle",
"onunhandledrejection",
"onunload", "onunload",
"onvolumechange", "onvolumechange",
"onwaiting", "onwaiting",
"onwheel",
"open", "open",
"optgroup", "optgroup",
"optimum", "optimum",
@ -264,9 +283,11 @@ var testAtomList = []string{
"p", "p",
"param", "param",
"pattern", "pattern",
"picture",
"ping", "ping",
"placeholder", "placeholder",
"plaintext", "plaintext",
"playsinline",
"poster", "poster",
"pre", "pre",
"preload", "preload",
@ -275,7 +296,9 @@ var testAtomList = []string{
"public", "public",
"q", "q",
"radiogroup", "radiogroup",
"rb",
"readonly", "readonly",
"referrerpolicy",
"rel", "rel",
"required", "required",
"reversed", "reversed",
@ -283,6 +306,7 @@ var testAtomList = []string{
"rowspan", "rowspan",
"rp", "rp",
"rt", "rt",
"rtc",
"ruby", "ruby",
"s", "s",
"samp", "samp",
@ -297,23 +321,23 @@ var testAtomList = []string{
"shape", "shape",
"size", "size",
"sizes", "sizes",
"slot",
"small", "small",
"sortable", "sortable",
"sorted", "sorted",
"source", "source",
"spacer", "spacer",
"span", "span",
"span",
"spellcheck", "spellcheck",
"src", "src",
"srcdoc", "srcdoc",
"srclang", "srclang",
"srcset",
"start", "start",
"step", "step",
"strike", "strike",
"strong", "strong",
"style", "style",
"style",
"sub", "sub",
"summary", "summary",
"sup", "sup",
@ -331,7 +355,6 @@ var testAtomList = []string{
"thead", "thead",
"time", "time",
"title", "title",
"title",
"tr", "tr",
"track", "track",
"translate", "translate",
@ -340,12 +363,14 @@ var testAtomList = []string{
"typemustmatch", "typemustmatch",
"u", "u",
"ul", "ul",
"updateviacache",
"usemap", "usemap",
"value", "value",
"var", "var",
"video", "video",
"wbr", "wbr",
"width", "width",
"workertype",
"wrap", "wrap",
"xmp", "xmp",
} }

View File

@ -4,7 +4,7 @@
package html package html
// Section 12.2.3.2 of the HTML5 specification says "The following elements // Section 12.2.4.2 of the HTML5 specification says "The following elements
// have varying levels of special parsing rules". // have varying levels of special parsing rules".
// https://html.spec.whatwg.org/multipage/syntax.html#the-stack-of-open-elements // https://html.spec.whatwg.org/multipage/syntax.html#the-stack-of-open-elements
var isSpecialElementMap = map[string]bool{ var isSpecialElementMap = map[string]bool{
@ -52,10 +52,12 @@ var isSpecialElementMap = map[string]bool{
"iframe": true, "iframe": true,
"img": true, "img": true,
"input": true, "input": true,
"isindex": true, "isindex": true, // The 'isindex' element has been removed, but keep it for backwards compatibility.
"keygen": true,
"li": true, "li": true,
"link": true, "link": true,
"listing": true, "listing": true,
"main": true,
"marquee": true, "marquee": true,
"menu": true, "menu": true,
"meta": true, "meta": true,

View File

@ -49,18 +49,18 @@ call to Next. For example, to extract an HTML page's anchor text:
for { for {
tt := z.Next() tt := z.Next()
switch tt { switch tt {
case ErrorToken: case html.ErrorToken:
return z.Err() return z.Err()
case TextToken: case html.TextToken:
if depth > 0 { if depth > 0 {
// emitBytes should copy the []byte it receives, // emitBytes should copy the []byte it receives,
// if it doesn't process it immediately. // if it doesn't process it immediately.
emitBytes(z.Text()) emitBytes(z.Text())
} }
case StartTagToken, EndTagToken: case html.StartTagToken, html.EndTagToken:
tn, _ := z.TagName() tn, _ := z.TagName()
if len(tn) == 1 && tn[0] == 'a' { if len(tn) == 1 && tn[0] == 'a' {
if tt == StartTagToken { if tt == html.StartTagToken {
depth++ depth++
} else { } else {
depth-- depth--

View File

@ -67,7 +67,7 @@ func mathMLTextIntegrationPoint(n *Node) bool {
return false return false
} }
// Section 12.2.5.5. // Section 12.2.6.5.
var breakout = map[string]bool{ var breakout = map[string]bool{
"b": true, "b": true,
"big": true, "big": true,
@ -115,7 +115,7 @@ var breakout = map[string]bool{
"var": true, "var": true,
} }
// Section 12.2.5.5. // Section 12.2.6.5.
var svgTagNameAdjustments = map[string]string{ var svgTagNameAdjustments = map[string]string{
"altglyph": "altGlyph", "altglyph": "altGlyph",
"altglyphdef": "altGlyphDef", "altglyphdef": "altGlyphDef",
@ -155,7 +155,7 @@ var svgTagNameAdjustments = map[string]string{
"textpath": "textPath", "textpath": "textPath",
} }
// Section 12.2.5.1 // Section 12.2.6.1
var mathMLAttributeAdjustments = map[string]string{ var mathMLAttributeAdjustments = map[string]string{
"definitionurl": "definitionURL", "definitionurl": "definitionURL",
} }

33
vendor/golang.org/x/net/html/node.go generated vendored
View File

@ -21,9 +21,10 @@ const (
scopeMarkerNode scopeMarkerNode
) )
// Section 12.2.3.3 says "scope markers are inserted when entering applet // Section 12.2.4.3 says "The markers are inserted when entering applet,
// elements, buttons, object elements, marquees, table cells, and table // object, marquee, template, td, th, and caption elements, and are used
// captions, and are used to prevent formatting from 'leaking'". // to prevent formatting from "leaking" into applet, object, marquee,
// template, td, th, and caption elements".
var scopeMarker = Node{Type: scopeMarkerNode} var scopeMarker = Node{Type: scopeMarkerNode}
// A Node consists of a NodeType and some Data (tag name for element nodes, // A Node consists of a NodeType and some Data (tag name for element nodes,
@ -173,6 +174,16 @@ func (s *nodeStack) index(n *Node) int {
return -1 return -1
} }
// contains returns whether a is within s.
func (s *nodeStack) contains(a atom.Atom) bool {
for _, n := range *s {
if n.DataAtom == a {
return true
}
}
return false
}
// insert inserts a node at the given index. // insert inserts a node at the given index.
func (s *nodeStack) insert(i int, n *Node) { func (s *nodeStack) insert(i int, n *Node) {
(*s) = append(*s, nil) (*s) = append(*s, nil)
@ -191,3 +202,19 @@ func (s *nodeStack) remove(n *Node) {
(*s)[j] = nil (*s)[j] = nil
*s = (*s)[:j] *s = (*s)[:j]
} }
type insertionModeStack []insertionMode
func (s *insertionModeStack) pop() (im insertionMode) {
i := len(*s)
im = (*s)[i-1]
*s = (*s)[:i-1]
return im
}
func (s *insertionModeStack) top() insertionMode {
if i := len(*s); i > 0 {
return (*s)[i-1]
}
return nil
}

361
vendor/golang.org/x/net/html/parse.go generated vendored
View File

@ -25,20 +25,22 @@ type parser struct {
hasSelfClosingToken bool hasSelfClosingToken bool
// doc is the document root element. // doc is the document root element.
doc *Node doc *Node
// The stack of open elements (section 12.2.3.2) and active formatting // The stack of open elements (section 12.2.4.2) and active formatting
// elements (section 12.2.3.3). // elements (section 12.2.4.3).
oe, afe nodeStack oe, afe nodeStack
// Element pointers (section 12.2.3.4). // Element pointers (section 12.2.4.4).
head, form *Node head, form *Node
// Other parsing state flags (section 12.2.3.5). // Other parsing state flags (section 12.2.4.5).
scripting, framesetOK bool scripting, framesetOK bool
// The stack of template insertion modes
templateStack insertionModeStack
// im is the current insertion mode. // im is the current insertion mode.
im insertionMode im insertionMode
// originalIM is the insertion mode to go back to after completing a text // originalIM is the insertion mode to go back to after completing a text
// or inTableText insertion mode. // or inTableText insertion mode.
originalIM insertionMode originalIM insertionMode
// fosterParenting is whether new elements should be inserted according to // fosterParenting is whether new elements should be inserted according to
// the foster parenting rules (section 12.2.5.3). // the foster parenting rules (section 12.2.6.1).
fosterParenting bool fosterParenting bool
// quirks is whether the parser is operating in "quirks mode." // quirks is whether the parser is operating in "quirks mode."
quirks bool quirks bool
@ -56,7 +58,7 @@ func (p *parser) top() *Node {
return p.doc return p.doc
} }
// Stop tags for use in popUntil. These come from section 12.2.3.2. // Stop tags for use in popUntil. These come from section 12.2.4.2.
var ( var (
defaultScopeStopTags = map[string][]a.Atom{ defaultScopeStopTags = map[string][]a.Atom{
"": {a.Applet, a.Caption, a.Html, a.Table, a.Td, a.Th, a.Marquee, a.Object, a.Template}, "": {a.Applet, a.Caption, a.Html, a.Table, a.Td, a.Th, a.Marquee, a.Object, a.Template},
@ -79,7 +81,7 @@ const (
// popUntil pops the stack of open elements at the highest element whose tag // popUntil pops the stack of open elements at the highest element whose tag
// is in matchTags, provided there is no higher element in the scope's stop // is in matchTags, provided there is no higher element in the scope's stop
// tags (as defined in section 12.2.3.2). It returns whether or not there was // tags (as defined in section 12.2.4.2). It returns whether or not there was
// such an element. If there was not, popUntil leaves the stack unchanged. // such an element. If there was not, popUntil leaves the stack unchanged.
// //
// For example, the set of stop tags for table scope is: "html", "table". If // For example, the set of stop tags for table scope is: "html", "table". If
@ -126,7 +128,7 @@ func (p *parser) indexOfElementInScope(s scope, matchTags ...a.Atom) int {
return -1 return -1
} }
case tableScope: case tableScope:
if tagAtom == a.Html || tagAtom == a.Table { if tagAtom == a.Html || tagAtom == a.Table || tagAtom == a.Template {
return -1 return -1
} }
case selectScope: case selectScope:
@ -162,17 +164,17 @@ func (p *parser) clearStackToContext(s scope) {
tagAtom := p.oe[i].DataAtom tagAtom := p.oe[i].DataAtom
switch s { switch s {
case tableScope: case tableScope:
if tagAtom == a.Html || tagAtom == a.Table { if tagAtom == a.Html || tagAtom == a.Table || tagAtom == a.Template {
p.oe = p.oe[:i+1] p.oe = p.oe[:i+1]
return return
} }
case tableRowScope: case tableRowScope:
if tagAtom == a.Html || tagAtom == a.Tr { if tagAtom == a.Html || tagAtom == a.Tr || tagAtom == a.Template {
p.oe = p.oe[:i+1] p.oe = p.oe[:i+1]
return return
} }
case tableBodyScope: case tableBodyScope:
if tagAtom == a.Html || tagAtom == a.Tbody || tagAtom == a.Tfoot || tagAtom == a.Thead { if tagAtom == a.Html || tagAtom == a.Tbody || tagAtom == a.Tfoot || tagAtom == a.Thead || tagAtom == a.Template {
p.oe = p.oe[:i+1] p.oe = p.oe[:i+1]
return return
} }
@ -183,7 +185,7 @@ func (p *parser) clearStackToContext(s scope) {
} }
// generateImpliedEndTags pops nodes off the stack of open elements as long as // generateImpliedEndTags pops nodes off the stack of open elements as long as
// the top node has a tag name of dd, dt, li, option, optgroup, p, rp, or rt. // the top node has a tag name of dd, dt, li, optgroup, option, p, rb, rp, rt or rtc.
// If exceptions are specified, nodes with that name will not be popped off. // If exceptions are specified, nodes with that name will not be popped off.
func (p *parser) generateImpliedEndTags(exceptions ...string) { func (p *parser) generateImpliedEndTags(exceptions ...string) {
var i int var i int
@ -192,7 +194,7 @@ loop:
n := p.oe[i] n := p.oe[i]
if n.Type == ElementNode { if n.Type == ElementNode {
switch n.DataAtom { switch n.DataAtom {
case a.Dd, a.Dt, a.Li, a.Option, a.Optgroup, a.P, a.Rp, a.Rt: case a.Dd, a.Dt, a.Li, a.Optgroup, a.Option, a.P, a.Rb, a.Rp, a.Rt, a.Rtc:
for _, except := range exceptions { for _, except := range exceptions {
if n.Data == except { if n.Data == except {
break loop break loop
@ -207,6 +209,27 @@ loop:
p.oe = p.oe[:i+1] p.oe = p.oe[:i+1]
} }
// generateAllImpliedEndTags pops nodes off the stack of open elements as long as
// the top node has a tag name of caption, colgroup, dd, div, dt, li, optgroup, option, p, rb,
// rp, rt, rtc, span, tbody, td, tfoot, th, thead or tr.
func (p *parser) generateAllImpliedEndTags() {
var i int
for i = len(p.oe) - 1; i >= 0; i-- {
n := p.oe[i]
if n.Type == ElementNode {
switch n.DataAtom {
// TODO: remove this divergence from the HTML5 spec
case a.Caption, a.Colgroup, a.Dd, a.Div, a.Dt, a.Li, a.Optgroup, a.Option, a.P, a.Rb,
a.Rp, a.Rt, a.Rtc, a.Span, a.Tbody, a.Td, a.Tfoot, a.Th, a.Thead, a.Tr:
continue
}
}
break
}
p.oe = p.oe[:i+1]
}
// addChild adds a child node n to the top element, and pushes n onto the stack // addChild adds a child node n to the top element, and pushes n onto the stack
// of open elements if it is an element node. // of open elements if it is an element node.
func (p *parser) addChild(n *Node) { func (p *parser) addChild(n *Node) {
@ -234,9 +257,9 @@ func (p *parser) shouldFosterParent() bool {
} }
// fosterParent adds a child node according to the foster parenting rules. // fosterParent adds a child node according to the foster parenting rules.
// Section 12.2.5.3, "foster parenting". // Section 12.2.6.1, "foster parenting".
func (p *parser) fosterParent(n *Node) { func (p *parser) fosterParent(n *Node) {
var table, parent, prev *Node var table, parent, prev, template *Node
var i int var i int
for i = len(p.oe) - 1; i >= 0; i-- { for i = len(p.oe) - 1; i >= 0; i-- {
if p.oe[i].DataAtom == a.Table { if p.oe[i].DataAtom == a.Table {
@ -245,6 +268,19 @@ func (p *parser) fosterParent(n *Node) {
} }
} }
var j int
for j = len(p.oe) - 1; j >= 0; j-- {
if p.oe[j].DataAtom == a.Template {
template = p.oe[j]
break
}
}
if template != nil && (table == nil || j < i) {
template.AppendChild(n)
return
}
if table == nil { if table == nil {
// The foster parent is the html element. // The foster parent is the html element.
parent = p.oe[0] parent = p.oe[0]
@ -304,7 +340,7 @@ func (p *parser) addElement() {
}) })
} }
// Section 12.2.3.3. // Section 12.2.4.3.
func (p *parser) addFormattingElement() { func (p *parser) addFormattingElement() {
tagAtom, attr := p.tok.DataAtom, p.tok.Attr tagAtom, attr := p.tok.DataAtom, p.tok.Attr
p.addElement() p.addElement()
@ -351,7 +387,7 @@ findIdenticalElements:
p.afe = append(p.afe, p.top()) p.afe = append(p.afe, p.top())
} }
// Section 12.2.3.3. // Section 12.2.4.3.
func (p *parser) clearActiveFormattingElements() { func (p *parser) clearActiveFormattingElements() {
for { for {
n := p.afe.pop() n := p.afe.pop()
@ -361,7 +397,7 @@ func (p *parser) clearActiveFormattingElements() {
} }
} }
// Section 12.2.3.3. // Section 12.2.4.3.
func (p *parser) reconstructActiveFormattingElements() { func (p *parser) reconstructActiveFormattingElements() {
n := p.afe.top() n := p.afe.top()
if n == nil { if n == nil {
@ -390,12 +426,12 @@ func (p *parser) reconstructActiveFormattingElements() {
} }
} }
// Section 12.2.4. // Section 12.2.5.
func (p *parser) acknowledgeSelfClosingTag() { func (p *parser) acknowledgeSelfClosingTag() {
p.hasSelfClosingToken = false p.hasSelfClosingToken = false
} }
// An insertion mode (section 12.2.3.1) is the state transition function from // An insertion mode (section 12.2.4.1) is the state transition function from
// a particular state in the HTML5 parser's state machine. It updates the // a particular state in the HTML5 parser's state machine. It updates the
// parser's fields depending on parser.tok (where ErrorToken means EOF). // parser's fields depending on parser.tok (where ErrorToken means EOF).
// It returns whether the token was consumed. // It returns whether the token was consumed.
@ -403,7 +439,7 @@ type insertionMode func(*parser) bool
// setOriginalIM sets the insertion mode to return to after completing a text or // setOriginalIM sets the insertion mode to return to after completing a text or
// inTableText insertion mode. // inTableText insertion mode.
// Section 12.2.3.1, "using the rules for". // Section 12.2.4.1, "using the rules for".
func (p *parser) setOriginalIM() { func (p *parser) setOriginalIM() {
if p.originalIM != nil { if p.originalIM != nil {
panic("html: bad parser state: originalIM was set twice") panic("html: bad parser state: originalIM was set twice")
@ -411,18 +447,38 @@ func (p *parser) setOriginalIM() {
p.originalIM = p.im p.originalIM = p.im
} }
// Section 12.2.3.1, "reset the insertion mode". // Section 12.2.4.1, "reset the insertion mode".
func (p *parser) resetInsertionMode() { func (p *parser) resetInsertionMode() {
for i := len(p.oe) - 1; i >= 0; i-- { for i := len(p.oe) - 1; i >= 0; i-- {
n := p.oe[i] n := p.oe[i]
if i == 0 && p.context != nil { last := i == 0
if last && p.context != nil {
n = p.context n = p.context
} }
switch n.DataAtom { switch n.DataAtom {
case a.Select: case a.Select:
if !last {
for ancestor, first := n, p.oe[0]; ancestor != first; {
if ancestor == first {
break
}
ancestor = p.oe[p.oe.index(ancestor)-1]
switch ancestor.DataAtom {
case a.Template:
p.im = inSelectIM
return
case a.Table:
p.im = inSelectInTableIM
return
}
}
}
p.im = inSelectIM p.im = inSelectIM
case a.Td, a.Th: case a.Td, a.Th:
// TODO: remove this divergence from the HTML5 spec.
//
// See https://bugs.chromium.org/p/chromium/issues/detail?id=829668
p.im = inCellIM p.im = inCellIM
case a.Tr: case a.Tr:
p.im = inRowIM p.im = inRowIM
@ -434,25 +490,37 @@ func (p *parser) resetInsertionMode() {
p.im = inColumnGroupIM p.im = inColumnGroupIM
case a.Table: case a.Table:
p.im = inTableIM p.im = inTableIM
case a.Template:
p.im = p.templateStack.top()
case a.Head: case a.Head:
p.im = inBodyIM // TODO: remove this divergence from the HTML5 spec.
//
// See https://bugs.chromium.org/p/chromium/issues/detail?id=829668
p.im = inHeadIM
case a.Body: case a.Body:
p.im = inBodyIM p.im = inBodyIM
case a.Frameset: case a.Frameset:
p.im = inFramesetIM p.im = inFramesetIM
case a.Html: case a.Html:
p.im = beforeHeadIM if p.head == nil {
p.im = beforeHeadIM
} else {
p.im = afterHeadIM
}
default: default:
if last {
p.im = inBodyIM
return
}
continue continue
} }
return return
} }
p.im = inBodyIM
} }
const whitespace = " \t\r\n\f" const whitespace = " \t\r\n\f"
// Section 12.2.5.4.1. // Section 12.2.6.4.1.
func initialIM(p *parser) bool { func initialIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case TextToken: case TextToken:
@ -479,7 +547,7 @@ func initialIM(p *parser) bool {
return false return false
} }
// Section 12.2.5.4.2. // Section 12.2.6.4.2.
func beforeHTMLIM(p *parser) bool { func beforeHTMLIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case DoctypeToken: case DoctypeToken:
@ -517,7 +585,7 @@ func beforeHTMLIM(p *parser) bool {
return false return false
} }
// Section 12.2.5.4.3. // Section 12.2.6.4.3.
func beforeHeadIM(p *parser) bool { func beforeHeadIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case TextToken: case TextToken:
@ -560,7 +628,7 @@ func beforeHeadIM(p *parser) bool {
return false return false
} }
// Section 12.2.5.4.4. // Section 12.2.6.4.4.
func inHeadIM(p *parser) bool { func inHeadIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case TextToken: case TextToken:
@ -590,19 +658,36 @@ func inHeadIM(p *parser) bool {
case a.Head: case a.Head:
// Ignore the token. // Ignore the token.
return true return true
case a.Template:
p.addElement()
p.afe = append(p.afe, &scopeMarker)
p.framesetOK = false
p.im = inTemplateIM
p.templateStack = append(p.templateStack, inTemplateIM)
return true
} }
case EndTagToken: case EndTagToken:
switch p.tok.DataAtom { switch p.tok.DataAtom {
case a.Head: case a.Head:
n := p.oe.pop() p.oe.pop()
if n.DataAtom != a.Head {
panic("html: bad parser state: <head> element not found, in the in-head insertion mode")
}
p.im = afterHeadIM p.im = afterHeadIM
return true return true
case a.Body, a.Html, a.Br: case a.Body, a.Html, a.Br:
p.parseImpliedToken(EndTagToken, a.Head, a.Head.String()) p.parseImpliedToken(EndTagToken, a.Head, a.Head.String())
return false return false
case a.Template:
if !p.oe.contains(a.Template) {
return true
}
p.generateAllImpliedEndTags()
if n := p.oe.top(); n.DataAtom != a.Template {
return true
}
p.popUntil(defaultScope, a.Template)
p.clearActiveFormattingElements()
p.templateStack.pop()
p.resetInsertionMode()
return true
default: default:
// Ignore the token. // Ignore the token.
return true return true
@ -622,7 +707,7 @@ func inHeadIM(p *parser) bool {
return false return false
} }
// Section 12.2.5.4.6. // Section 12.2.6.4.6.
func afterHeadIM(p *parser) bool { func afterHeadIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case TextToken: case TextToken:
@ -648,7 +733,7 @@ func afterHeadIM(p *parser) bool {
p.addElement() p.addElement()
p.im = inFramesetIM p.im = inFramesetIM
return true return true
case a.Base, a.Basefont, a.Bgsound, a.Link, a.Meta, a.Noframes, a.Script, a.Style, a.Title: case a.Base, a.Basefont, a.Bgsound, a.Link, a.Meta, a.Noframes, a.Script, a.Style, a.Template, a.Title:
p.oe = append(p.oe, p.head) p.oe = append(p.oe, p.head)
defer p.oe.remove(p.head) defer p.oe.remove(p.head)
return inHeadIM(p) return inHeadIM(p)
@ -660,6 +745,8 @@ func afterHeadIM(p *parser) bool {
switch p.tok.DataAtom { switch p.tok.DataAtom {
case a.Body, a.Html, a.Br: case a.Body, a.Html, a.Br:
// Drop down to creating an implied <body> tag. // Drop down to creating an implied <body> tag.
case a.Template:
return inHeadIM(p)
default: default:
// Ignore the token. // Ignore the token.
return true return true
@ -697,7 +784,7 @@ func copyAttributes(dst *Node, src Token) {
} }
} }
// Section 12.2.5.4.7. // Section 12.2.6.4.7.
func inBodyIM(p *parser) bool { func inBodyIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case TextToken: case TextToken:
@ -727,10 +814,16 @@ func inBodyIM(p *parser) bool {
case StartTagToken: case StartTagToken:
switch p.tok.DataAtom { switch p.tok.DataAtom {
case a.Html: case a.Html:
if p.oe.contains(a.Template) {
return true
}
copyAttributes(p.oe[0], p.tok) copyAttributes(p.oe[0], p.tok)
case a.Base, a.Basefont, a.Bgsound, a.Command, a.Link, a.Meta, a.Noframes, a.Script, a.Style, a.Title: case a.Base, a.Basefont, a.Bgsound, a.Command, a.Link, a.Meta, a.Noframes, a.Script, a.Style, a.Template, a.Title:
return inHeadIM(p) return inHeadIM(p)
case a.Body: case a.Body:
if p.oe.contains(a.Template) {
return true
}
if len(p.oe) >= 2 { if len(p.oe) >= 2 {
body := p.oe[1] body := p.oe[1]
if body.Type == ElementNode && body.DataAtom == a.Body { if body.Type == ElementNode && body.DataAtom == a.Body {
@ -767,7 +860,7 @@ func inBodyIM(p *parser) bool {
// The newline, if any, will be dealt with by the TextToken case. // The newline, if any, will be dealt with by the TextToken case.
p.framesetOK = false p.framesetOK = false
case a.Form: case a.Form:
if p.form == nil { if p.oe.contains(a.Template) || p.form == nil {
p.popUntil(buttonScope, a.P) p.popUntil(buttonScope, a.P)
p.addElement() p.addElement()
p.form = p.top() p.form = p.top()
@ -952,11 +1045,16 @@ func inBodyIM(p *parser) bool {
} }
p.reconstructActiveFormattingElements() p.reconstructActiveFormattingElements()
p.addElement() p.addElement()
case a.Rp, a.Rt: case a.Rb, a.Rtc:
if p.elementInScope(defaultScope, a.Ruby) { if p.elementInScope(defaultScope, a.Ruby) {
p.generateImpliedEndTags() p.generateImpliedEndTags()
} }
p.addElement() p.addElement()
case a.Rp, a.Rt:
if p.elementInScope(defaultScope, a.Ruby) {
p.generateImpliedEndTags("rtc")
}
p.addElement()
case a.Math, a.Svg: case a.Math, a.Svg:
p.reconstructActiveFormattingElements() p.reconstructActiveFormattingElements()
if p.tok.DataAtom == a.Math { if p.tok.DataAtom == a.Math {
@ -972,7 +1070,13 @@ func inBodyIM(p *parser) bool {
p.acknowledgeSelfClosingTag() p.acknowledgeSelfClosingTag()
} }
return true return true
case a.Caption, a.Col, a.Colgroup, a.Frame, a.Head, a.Tbody, a.Td, a.Tfoot, a.Th, a.Thead, a.Tr: case a.Frame:
// TODO: remove this divergence from the HTML5 spec.
if p.oe.contains(a.Template) {
p.addElement()
return true
}
case a.Caption, a.Col, a.Colgroup, a.Head, a.Tbody, a.Td, a.Tfoot, a.Th, a.Thead, a.Tr:
// Ignore the token. // Ignore the token.
default: default:
p.reconstructActiveFormattingElements() p.reconstructActiveFormattingElements()
@ -993,15 +1097,28 @@ func inBodyIM(p *parser) bool {
case a.Address, a.Article, a.Aside, a.Blockquote, a.Button, a.Center, a.Details, a.Dir, a.Div, a.Dl, a.Fieldset, a.Figcaption, a.Figure, a.Footer, a.Header, a.Hgroup, a.Listing, a.Menu, a.Nav, a.Ol, a.Pre, a.Section, a.Summary, a.Ul: case a.Address, a.Article, a.Aside, a.Blockquote, a.Button, a.Center, a.Details, a.Dir, a.Div, a.Dl, a.Fieldset, a.Figcaption, a.Figure, a.Footer, a.Header, a.Hgroup, a.Listing, a.Menu, a.Nav, a.Ol, a.Pre, a.Section, a.Summary, a.Ul:
p.popUntil(defaultScope, p.tok.DataAtom) p.popUntil(defaultScope, p.tok.DataAtom)
case a.Form: case a.Form:
node := p.form if p.oe.contains(a.Template) {
p.form = nil if !p.oe.contains(a.Form) {
i := p.indexOfElementInScope(defaultScope, a.Form) // Ignore the token.
if node == nil || i == -1 || p.oe[i] != node { return true
// Ignore the token. }
return true p.generateImpliedEndTags()
if p.tok.DataAtom == a.Form {
// Ignore the token.
return true
}
p.popUntil(defaultScope, a.Form)
} else {
node := p.form
p.form = nil
i := p.indexOfElementInScope(defaultScope, a.Form)
if node == nil || i == -1 || p.oe[i] != node {
// Ignore the token.
return true
}
p.generateImpliedEndTags()
p.oe.remove(node)
} }
p.generateImpliedEndTags()
p.oe.remove(node)
case a.P: case a.P:
if !p.elementInScope(buttonScope, a.P) { if !p.elementInScope(buttonScope, a.P) {
p.parseImpliedToken(StartTagToken, a.P, a.P.String()) p.parseImpliedToken(StartTagToken, a.P, a.P.String())
@ -1022,6 +1139,8 @@ func inBodyIM(p *parser) bool {
case a.Br: case a.Br:
p.tok.Type = StartTagToken p.tok.Type = StartTagToken
return false return false
case a.Template:
return inHeadIM(p)
default: default:
p.inBodyEndTagOther(p.tok.DataAtom) p.inBodyEndTagOther(p.tok.DataAtom)
} }
@ -1030,6 +1149,21 @@ func inBodyIM(p *parser) bool {
Type: CommentNode, Type: CommentNode,
Data: p.tok.Data, Data: p.tok.Data,
}) })
case ErrorToken:
// TODO: remove this divergence from the HTML5 spec.
if len(p.templateStack) > 0 {
p.im = inTemplateIM
return false
} else {
for _, e := range p.oe {
switch e.DataAtom {
case a.Dd, a.Dt, a.Li, a.Optgroup, a.Option, a.P, a.Rb, a.Rp, a.Rt, a.Rtc, a.Tbody, a.Td, a.Tfoot, a.Th,
a.Thead, a.Tr, a.Body, a.Html:
default:
return true
}
}
}
} }
return true return true
@ -1135,6 +1269,12 @@ func (p *parser) inBodyEndTagFormatting(tagAtom a.Atom) {
switch commonAncestor.DataAtom { switch commonAncestor.DataAtom {
case a.Table, a.Tbody, a.Tfoot, a.Thead, a.Tr: case a.Table, a.Tbody, a.Tfoot, a.Thead, a.Tr:
p.fosterParent(lastNode) p.fosterParent(lastNode)
case a.Template:
// TODO: remove namespace checking
if commonAncestor.Namespace == "html" {
commonAncestor = commonAncestor.LastChild
}
fallthrough
default: default:
commonAncestor.AppendChild(lastNode) commonAncestor.AppendChild(lastNode)
} }
@ -1160,7 +1300,7 @@ func (p *parser) inBodyEndTagFormatting(tagAtom a.Atom) {
} }
// inBodyEndTagOther performs the "any other end tag" algorithm for inBodyIM. // inBodyEndTagOther performs the "any other end tag" algorithm for inBodyIM.
// "Any other end tag" handling from 12.2.5.5 The rules for parsing tokens in foreign content // "Any other end tag" handling from 12.2.6.5 The rules for parsing tokens in foreign content
// https://html.spec.whatwg.org/multipage/syntax.html#parsing-main-inforeign // https://html.spec.whatwg.org/multipage/syntax.html#parsing-main-inforeign
func (p *parser) inBodyEndTagOther(tagAtom a.Atom) { func (p *parser) inBodyEndTagOther(tagAtom a.Atom) {
for i := len(p.oe) - 1; i >= 0; i-- { for i := len(p.oe) - 1; i >= 0; i-- {
@ -1174,7 +1314,7 @@ func (p *parser) inBodyEndTagOther(tagAtom a.Atom) {
} }
} }
// Section 12.2.5.4.8. // Section 12.2.6.4.8.
func textIM(p *parser) bool { func textIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case ErrorToken: case ErrorToken:
@ -1203,7 +1343,7 @@ func textIM(p *parser) bool {
return p.tok.Type == EndTagToken return p.tok.Type == EndTagToken
} }
// Section 12.2.5.4.9. // Section 12.2.6.4.9.
func inTableIM(p *parser) bool { func inTableIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case ErrorToken: case ErrorToken:
@ -1249,7 +1389,7 @@ func inTableIM(p *parser) bool {
} }
// Ignore the token. // Ignore the token.
return true return true
case a.Style, a.Script: case a.Style, a.Script, a.Template:
return inHeadIM(p) return inHeadIM(p)
case a.Input: case a.Input:
for _, t := range p.tok.Attr { for _, t := range p.tok.Attr {
@ -1261,7 +1401,7 @@ func inTableIM(p *parser) bool {
} }
// Otherwise drop down to the default action. // Otherwise drop down to the default action.
case a.Form: case a.Form:
if p.form != nil { if p.oe.contains(a.Template) || p.form != nil {
// Ignore the token. // Ignore the token.
return true return true
} }
@ -1291,6 +1431,8 @@ func inTableIM(p *parser) bool {
case a.Body, a.Caption, a.Col, a.Colgroup, a.Html, a.Tbody, a.Td, a.Tfoot, a.Th, a.Thead, a.Tr: case a.Body, a.Caption, a.Col, a.Colgroup, a.Html, a.Tbody, a.Td, a.Tfoot, a.Th, a.Thead, a.Tr:
// Ignore the token. // Ignore the token.
return true return true
case a.Template:
return inHeadIM(p)
} }
case CommentToken: case CommentToken:
p.addChild(&Node{ p.addChild(&Node{
@ -1309,7 +1451,7 @@ func inTableIM(p *parser) bool {
return inBodyIM(p) return inBodyIM(p)
} }
// Section 12.2.5.4.11. // Section 12.2.6.4.11.
func inCaptionIM(p *parser) bool { func inCaptionIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case StartTagToken: case StartTagToken:
@ -1355,7 +1497,7 @@ func inCaptionIM(p *parser) bool {
return inBodyIM(p) return inBodyIM(p)
} }
// Section 12.2.5.4.12. // Section 12.2.6.4.12.
func inColumnGroupIM(p *parser) bool { func inColumnGroupIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case TextToken: case TextToken:
@ -1386,11 +1528,13 @@ func inColumnGroupIM(p *parser) bool {
p.oe.pop() p.oe.pop()
p.acknowledgeSelfClosingTag() p.acknowledgeSelfClosingTag()
return true return true
case a.Template:
return inHeadIM(p)
} }
case EndTagToken: case EndTagToken:
switch p.tok.DataAtom { switch p.tok.DataAtom {
case a.Colgroup: case a.Colgroup:
if p.oe.top().DataAtom != a.Html { if p.oe.top().DataAtom == a.Colgroup {
p.oe.pop() p.oe.pop()
p.im = inTableIM p.im = inTableIM
} }
@ -1398,17 +1542,19 @@ func inColumnGroupIM(p *parser) bool {
case a.Col: case a.Col:
// Ignore the token. // Ignore the token.
return true return true
case a.Template:
return inHeadIM(p)
} }
} }
if p.oe.top().DataAtom != a.Html { if p.oe.top().DataAtom != a.Colgroup {
p.oe.pop() return true
p.im = inTableIM
return false
} }
return true p.oe.pop()
p.im = inTableIM
return false
} }
// Section 12.2.5.4.13. // Section 12.2.6.4.13.
func inTableBodyIM(p *parser) bool { func inTableBodyIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case StartTagToken: case StartTagToken:
@ -1460,7 +1606,7 @@ func inTableBodyIM(p *parser) bool {
return inTableIM(p) return inTableIM(p)
} }
// Section 12.2.5.4.14. // Section 12.2.6.4.14.
func inRowIM(p *parser) bool { func inRowIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case StartTagToken: case StartTagToken:
@ -1511,7 +1657,7 @@ func inRowIM(p *parser) bool {
return inTableIM(p) return inTableIM(p)
} }
// Section 12.2.5.4.15. // Section 12.2.6.4.15.
func inCellIM(p *parser) bool { func inCellIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case StartTagToken: case StartTagToken:
@ -1560,7 +1706,7 @@ func inCellIM(p *parser) bool {
return inBodyIM(p) return inBodyIM(p)
} }
// Section 12.2.5.4.16. // Section 12.2.6.4.16.
func inSelectIM(p *parser) bool { func inSelectIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case ErrorToken: case ErrorToken:
@ -1597,7 +1743,7 @@ func inSelectIM(p *parser) bool {
p.tokenizer.NextIsNotRawText() p.tokenizer.NextIsNotRawText()
// Ignore the token. // Ignore the token.
return true return true
case a.Script: case a.Script, a.Template:
return inHeadIM(p) return inHeadIM(p)
} }
case EndTagToken: case EndTagToken:
@ -1618,6 +1764,8 @@ func inSelectIM(p *parser) bool {
if p.popUntil(selectScope, a.Select) { if p.popUntil(selectScope, a.Select) {
p.resetInsertionMode() p.resetInsertionMode()
} }
case a.Template:
return inHeadIM(p)
} }
case CommentToken: case CommentToken:
p.addChild(&Node{ p.addChild(&Node{
@ -1632,7 +1780,7 @@ func inSelectIM(p *parser) bool {
return true return true
} }
// Section 12.2.5.4.17. // Section 12.2.6.4.17.
func inSelectInTableIM(p *parser) bool { func inSelectInTableIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case StartTagToken, EndTagToken: case StartTagToken, EndTagToken:
@ -1650,7 +1798,62 @@ func inSelectInTableIM(p *parser) bool {
return inSelectIM(p) return inSelectIM(p)
} }
// Section 12.2.5.4.18. // Section 12.2.6.4.18.
func inTemplateIM(p *parser) bool {
switch p.tok.Type {
case TextToken, CommentToken, DoctypeToken:
return inBodyIM(p)
case StartTagToken:
switch p.tok.DataAtom {
case a.Base, a.Basefont, a.Bgsound, a.Link, a.Meta, a.Noframes, a.Script, a.Style, a.Template, a.Title:
return inHeadIM(p)
case a.Caption, a.Colgroup, a.Tbody, a.Tfoot, a.Thead:
p.templateStack.pop()
p.templateStack = append(p.templateStack, inTableIM)
p.im = inTableIM
return false
case a.Col:
p.templateStack.pop()
p.templateStack = append(p.templateStack, inColumnGroupIM)
p.im = inColumnGroupIM
return false
case a.Tr:
p.templateStack.pop()
p.templateStack = append(p.templateStack, inTableBodyIM)
p.im = inTableBodyIM
return false
case a.Td, a.Th:
p.templateStack.pop()
p.templateStack = append(p.templateStack, inRowIM)
p.im = inRowIM
return false
default:
p.templateStack.pop()
p.templateStack = append(p.templateStack, inBodyIM)
p.im = inBodyIM
return false
}
case EndTagToken:
switch p.tok.DataAtom {
case a.Template:
return inHeadIM(p)
default:
// Ignore the token.
return true
}
}
if !p.oe.contains(a.Template) {
// Ignore the token.
return true
}
p.popUntil(defaultScope, a.Template)
p.clearActiveFormattingElements()
p.templateStack.pop()
p.resetInsertionMode()
return false
}
// Section 12.2.6.4.19.
func afterBodyIM(p *parser) bool { func afterBodyIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case ErrorToken: case ErrorToken:
@ -1688,7 +1891,7 @@ func afterBodyIM(p *parser) bool {
return false return false
} }
// Section 12.2.5.4.19. // Section 12.2.6.4.20.
func inFramesetIM(p *parser) bool { func inFramesetIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case CommentToken: case CommentToken:
@ -1720,6 +1923,11 @@ func inFramesetIM(p *parser) bool {
p.acknowledgeSelfClosingTag() p.acknowledgeSelfClosingTag()
case a.Noframes: case a.Noframes:
return inHeadIM(p) return inHeadIM(p)
case a.Template:
// TODO: remove this divergence from the HTML5 spec.
//
// See https://bugs.chromium.org/p/chromium/issues/detail?id=829668
return inTemplateIM(p)
} }
case EndTagToken: case EndTagToken:
switch p.tok.DataAtom { switch p.tok.DataAtom {
@ -1738,7 +1946,7 @@ func inFramesetIM(p *parser) bool {
return true return true
} }
// Section 12.2.5.4.20. // Section 12.2.6.4.21.
func afterFramesetIM(p *parser) bool { func afterFramesetIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case CommentToken: case CommentToken:
@ -1777,7 +1985,7 @@ func afterFramesetIM(p *parser) bool {
return true return true
} }
// Section 12.2.5.4.21. // Section 12.2.6.4.22.
func afterAfterBodyIM(p *parser) bool { func afterAfterBodyIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case ErrorToken: case ErrorToken:
@ -1806,7 +2014,7 @@ func afterAfterBodyIM(p *parser) bool {
return false return false
} }
// Section 12.2.5.4.22. // Section 12.2.6.4.23.
func afterAfterFramesetIM(p *parser) bool { func afterAfterFramesetIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case CommentToken: case CommentToken:
@ -1844,7 +2052,7 @@ func afterAfterFramesetIM(p *parser) bool {
const whitespaceOrNUL = whitespace + "\x00" const whitespaceOrNUL = whitespace + "\x00"
// Section 12.2.5.5. // Section 12.2.6.5
func parseForeignContent(p *parser) bool { func parseForeignContent(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case TextToken: case TextToken:
@ -1924,7 +2132,7 @@ func parseForeignContent(p *parser) bool {
return true return true
} }
// Section 12.2.5. // Section 12.2.6.
func (p *parser) inForeignContent() bool { func (p *parser) inForeignContent() bool {
if len(p.oe) == 0 { if len(p.oe) == 0 {
return false return false
@ -2064,6 +2272,9 @@ func ParseFragment(r io.Reader, context *Node) ([]*Node, error) {
} }
p.doc.AppendChild(root) p.doc.AppendChild(root)
p.oe = nodeStack{root} p.oe = nodeStack{root}
if context != nil && context.DataAtom == a.Template {
p.templateStack = append(p.templateStack, inTemplateIM)
}
p.resetInsertionMode() p.resetInsertionMode()
for n := context; n != nil; n = n.Parent { for n := context; n != nil; n = n.Parent {

View File

@ -125,6 +125,7 @@ func (a sortedAttributes) Swap(i, j int) {
func dumpLevel(w io.Writer, n *Node, level int) error { func dumpLevel(w io.Writer, n *Node, level int) error {
dumpIndent(w, level) dumpIndent(w, level)
level++
switch n.Type { switch n.Type {
case ErrorNode: case ErrorNode:
return errors.New("unexpected ErrorNode") return errors.New("unexpected ErrorNode")
@ -140,13 +141,19 @@ func dumpLevel(w io.Writer, n *Node, level int) error {
sort.Sort(attr) sort.Sort(attr)
for _, a := range attr { for _, a := range attr {
io.WriteString(w, "\n") io.WriteString(w, "\n")
dumpIndent(w, level+1) dumpIndent(w, level)
if a.Namespace != "" { if a.Namespace != "" {
fmt.Fprintf(w, `%s %s="%s"`, a.Namespace, a.Key, a.Val) fmt.Fprintf(w, `%s %s="%s"`, a.Namespace, a.Key, a.Val)
} else { } else {
fmt.Fprintf(w, `%s="%s"`, a.Key, a.Val) fmt.Fprintf(w, `%s="%s"`, a.Key, a.Val)
} }
} }
if n.Namespace == "" && n.DataAtom == atom.Template {
io.WriteString(w, "\n")
dumpIndent(w, level)
level++
io.WriteString(w, "content")
}
case TextNode: case TextNode:
fmt.Fprintf(w, `"%s"`, n.Data) fmt.Fprintf(w, `"%s"`, n.Data)
case CommentNode: case CommentNode:
@ -176,7 +183,7 @@ func dumpLevel(w io.Writer, n *Node, level int) error {
} }
io.WriteString(w, "\n") io.WriteString(w, "\n")
for c := n.FirstChild; c != nil; c = c.NextSibling { for c := n.FirstChild; c != nil; c = c.NextSibling {
if err := dumpLevel(w, c, level+1); err != nil { if err := dumpLevel(w, c, level); err != nil {
return err return err
} }
} }
@ -373,6 +380,11 @@ func TestNodeConsistency(t *testing.T) {
} }
} }
func TestParseFragmentWithNilContext(t *testing.T) {
// This shouldn't panic.
ParseFragment(strings.NewReader("<p>hello</p>"), nil)
}
func BenchmarkParser(b *testing.B) { func BenchmarkParser(b *testing.B) {
buf, err := ioutil.ReadFile("testdata/go1.html") buf, err := ioutil.ReadFile("testdata/go1.html")
if err != nil { if err != nil {

298
vendor/golang.org/x/net/html/testdata/webkit/ruby.dat generated vendored Normal file
View File

@ -0,0 +1,298 @@
#data
<html><ruby>a<rb>b<rb></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rb>
| "b"
| <rb>
#data
<html><ruby>a<rb>b<rt></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rb>
| "b"
| <rt>
#data
<html><ruby>a<rb>b<rtc></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rb>
| "b"
| <rtc>
#data
<html><ruby>a<rb>b<rp></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rb>
| "b"
| <rp>
#data
<html><ruby>a<rb>b<span></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rb>
| "b"
| <span>
#data
<html><ruby>a<rt>b<rb></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rt>
| "b"
| <rb>
#data
<html><ruby>a<rt>b<rt></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rt>
| "b"
| <rt>
#data
<html><ruby>a<rt>b<rtc></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rt>
| "b"
| <rtc>
#data
<html><ruby>a<rt>b<rp></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rt>
| "b"
| <rp>
#data
<html><ruby>a<rt>b<span></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rt>
| "b"
| <span>
#data
<html><ruby>a<rtc>b<rb></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rtc>
| "b"
| <rb>
#data
<html><ruby>a<rtc>b<rt>c<rt>d</ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rtc>
| "b"
| <rt>
| "c"
| <rt>
| "d"
#data
<html><ruby>a<rtc>b<rtc></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rtc>
| "b"
| <rtc>
#data
<html><ruby>a<rtc>b<rp></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rtc>
| "b"
| <rp>
#data
<html><ruby>a<rtc>b<span></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rtc>
| "b"
| <span>
#data
<html><ruby>a<rp>b<rb></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rp>
| "b"
| <rb>
#data
<html><ruby>a<rp>b<rt></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rp>
| "b"
| <rt>
#data
<html><ruby>a<rp>b<rtc></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rp>
| "b"
| <rtc>
#data
<html><ruby>a<rp>b<rp></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rp>
| "b"
| <rp>
#data
<html><ruby>a<rp>b<span></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| "a"
| <rp>
| "b"
| <span>
#data
<html><ruby><rtc><ruby>a<rb>b<rt></ruby></ruby></html>
#errors
(1,6): expected-doctype-but-got-start-tag
#document
| <html>
| <head>
| <body>
| <ruby>
| <rtc>
| <ruby>
| "a"
| <rb>
| "b"
| <rt>

1117
vendor/golang.org/x/net/html/testdata/webkit/template.dat generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1161,8 +1161,8 @@ func (z *Tokenizer) TagAttr() (key, val []byte, moreAttr bool) {
return nil, nil, false return nil, nil, false
} }
// Token returns the next Token. The result's Data and Attr values remain valid // Token returns the current Token. The result's Data and Attr values remain
// after subsequent Next calls. // valid after subsequent Next calls.
func (z *Tokenizer) Token() Token { func (z *Tokenizer) Token() Token {
t := Token{Type: z.tt} t := Token{Type: z.tt}
switch z.tt { switch z.tt {

50
vendor/golang.org/x/net/http/httpguts/guts.go generated vendored Normal file
View File

@ -0,0 +1,50 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package httpguts provides functions implementing various details
// of the HTTP specification.
//
// This package is shared by the standard library (which vendors it)
// and x/net/http2. It comes with no API stability promise.
package httpguts
import (
"net/textproto"
"strings"
)
// ValidTrailerHeader reports whether name is a valid header field name to appear
// in trailers.
// See RFC 7230, Section 4.1.2
func ValidTrailerHeader(name string) bool {
name = textproto.CanonicalMIMEHeaderKey(name)
if strings.HasPrefix(name, "If-") || badTrailer[name] {
return false
}
return true
}
var badTrailer = map[string]bool{
"Authorization": true,
"Cache-Control": true,
"Connection": true,
"Content-Encoding": true,
"Content-Length": true,
"Content-Range": true,
"Content-Type": true,
"Expect": true,
"Host": true,
"Keep-Alive": true,
"Max-Forwards": true,
"Pragma": true,
"Proxy-Authenticate": true,
"Proxy-Authorization": true,
"Proxy-Connection": true,
"Range": true,
"Realm": true,
"Te": true,
"Trailer": true,
"Transfer-Encoding": true,
"Www-Authenticate": true,
}

View File

@ -0,0 +1,7 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httpproxy
var ExportUseProxy = (*Config).useProxy

13
vendor/golang.org/x/net/http/httpproxy/go19_test.go generated vendored Normal file
View File

@ -0,0 +1,13 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.9
package httpproxy_test
import "testing"
func init() {
setHelper = func(t *testing.T) { t.Helper() }
}

239
vendor/golang.org/x/net/http/httpproxy/proxy.go generated vendored Normal file
View File

@ -0,0 +1,239 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package httpproxy provides support for HTTP proxy determination
// based on environment variables, as provided by net/http's
// ProxyFromEnvironment function.
//
// The API is not subject to the Go 1 compatibility promise and may change at
// any time.
package httpproxy
import (
"errors"
"fmt"
"net"
"net/url"
"os"
"strings"
"unicode/utf8"
"golang.org/x/net/idna"
)
// Config holds configuration for HTTP proxy settings. See
// FromEnvironment for details.
type Config struct {
// HTTPProxy represents the value of the HTTP_PROXY or
// http_proxy environment variable. It will be used as the proxy
// URL for HTTP requests and HTTPS requests unless overridden by
// HTTPSProxy or NoProxy.
HTTPProxy string
// HTTPSProxy represents the HTTPS_PROXY or https_proxy
// environment variable. It will be used as the proxy URL for
// HTTPS requests unless overridden by NoProxy.
HTTPSProxy string
// NoProxy represents the NO_PROXY or no_proxy environment
// variable. It specifies URLs that should be excluded from
// proxying as a comma-separated list of domain names or a
// single asterisk (*) to indicate that no proxying should be
// done. A domain name matches that name and all subdomains. A
// domain name with a leading "." matches subdomains only. For
// example "foo.com" matches "foo.com" and "bar.foo.com";
// ".y.com" matches "x.y.com" but not "y.com".
NoProxy string
// CGI holds whether the current process is running
// as a CGI handler (FromEnvironment infers this from the
// presence of a REQUEST_METHOD environment variable).
// When this is set, ProxyForURL will return an error
// when HTTPProxy applies, because a client could be
// setting HTTP_PROXY maliciously. See https://golang.org/s/cgihttpproxy.
CGI bool
}
// FromEnvironment returns a Config instance populated from the
// environment variables HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or the
// lowercase versions thereof). HTTPS_PROXY takes precedence over
// HTTP_PROXY for https requests.
//
// The environment values may be either a complete URL or a
// "host[:port]", in which case the "http" scheme is assumed. An error
// is returned if the value is a different form.
func FromEnvironment() *Config {
return &Config{
HTTPProxy: getEnvAny("HTTP_PROXY", "http_proxy"),
HTTPSProxy: getEnvAny("HTTPS_PROXY", "https_proxy"),
NoProxy: getEnvAny("NO_PROXY", "no_proxy"),
CGI: os.Getenv("REQUEST_METHOD") != "",
}
}
func getEnvAny(names ...string) string {
for _, n := range names {
if val := os.Getenv(n); val != "" {
return val
}
}
return ""
}
// ProxyFunc returns a function that determines the proxy URL to use for
// a given request URL. Changing the contents of cfg will not affect
// proxy functions created earlier.
//
// A nil URL and nil error are returned if no proxy is defined in the
// environment, or a proxy should not be used for the given request, as
// defined by NO_PROXY.
//
// As a special case, if req.URL.Host is "localhost" (with or without a
// port number), then a nil URL and nil error will be returned.
func (cfg *Config) ProxyFunc() func(reqURL *url.URL) (*url.URL, error) {
// Prevent Config changes from affecting the function calculation.
// TODO Preprocess proxy settings for more efficient evaluation.
cfg1 := *cfg
return cfg1.proxyForURL
}
func (cfg *Config) proxyForURL(reqURL *url.URL) (*url.URL, error) {
var proxy string
if reqURL.Scheme == "https" {
proxy = cfg.HTTPSProxy
}
if proxy == "" {
proxy = cfg.HTTPProxy
if proxy != "" && cfg.CGI {
return nil, errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")
}
}
if proxy == "" {
return nil, nil
}
if !cfg.useProxy(canonicalAddr(reqURL)) {
return nil, nil
}
proxyURL, err := url.Parse(proxy)
if err != nil ||
(proxyURL.Scheme != "http" &&
proxyURL.Scheme != "https" &&
proxyURL.Scheme != "socks5") {
// proxy was bogus. Try prepending "http://" to it and
// see if that parses correctly. If not, we fall
// through and complain about the original one.
if proxyURL, err := url.Parse("http://" + proxy); err == nil {
return proxyURL, nil
}
}
if err != nil {
return nil, fmt.Errorf("invalid proxy address %q: %v", proxy, err)
}
return proxyURL, nil
}
// useProxy reports whether requests to addr should use a proxy,
// according to the NO_PROXY or no_proxy environment variable.
// addr is always a canonicalAddr with a host and port.
func (cfg *Config) useProxy(addr string) bool {
if len(addr) == 0 {
return true
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return false
}
if host == "localhost" {
return false
}
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() {
return false
}
}
noProxy := cfg.NoProxy
if noProxy == "*" {
return false
}
addr = strings.ToLower(strings.TrimSpace(addr))
if hasPort(addr) {
addr = addr[:strings.LastIndex(addr, ":")]
}
for _, p := range strings.Split(noProxy, ",") {
p = strings.ToLower(strings.TrimSpace(p))
if len(p) == 0 {
continue
}
if hasPort(p) {
p = p[:strings.LastIndex(p, ":")]
}
if addr == p {
return false
}
if len(p) == 0 {
// There is no host part, likely the entry is malformed; ignore.
continue
}
if p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:]) {
// no_proxy ".foo.com" matches "bar.foo.com" or "foo.com"
return false
}
if p[0] != '.' && strings.HasSuffix(addr, p) && addr[len(addr)-len(p)-1] == '.' {
// no_proxy "foo.com" matches "bar.foo.com"
return false
}
}
return true
}
var portMap = map[string]string{
"http": "80",
"https": "443",
"socks5": "1080",
}
// canonicalAddr returns url.Host but always with a ":port" suffix
func canonicalAddr(url *url.URL) string {
addr := url.Hostname()
if v, err := idnaASCII(addr); err == nil {
addr = v
}
port := url.Port()
if port == "" {
port = portMap[url.Scheme]
}
return net.JoinHostPort(addr, port)
}
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
// return true if the string includes a port.
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
func idnaASCII(v string) (string, error) {
// TODO: Consider removing this check after verifying performance is okay.
// Right now punycode verification, length checks, context checks, and the
// permissible character tests are all omitted. It also prevents the ToASCII
// call from salvaging an invalid IDN, when possible. As a result it may be
// possible to have two IDNs that appear identical to the user where the
// ASCII-only version causes an error downstream whereas the non-ASCII
// version does not.
// Note that for correct ASCII IDNs ToASCII will only do considerably more
// work, but it will not cause an allocation.
if isASCII(v) {
return v, nil
}
return idna.Lookup.ToASCII(v)
}
func isASCII(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] >= utf8.RuneSelf {
return false
}
}
return true
}

301
vendor/golang.org/x/net/http/httpproxy/proxy_test.go generated vendored Normal file
View File

@ -0,0 +1,301 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httpproxy_test
import (
"bytes"
"errors"
"fmt"
"net/url"
"os"
"strings"
"testing"
"golang.org/x/net/http/httpproxy"
)
// setHelper calls t.Helper() for Go 1.9+ (see go19_test.go) and does nothing otherwise.
var setHelper = func(t *testing.T) {}
type proxyForURLTest struct {
cfg httpproxy.Config
req string // URL to fetch; blank means "http://example.com"
want string
wanterr error
}
func (t proxyForURLTest) String() string {
var buf bytes.Buffer
space := func() {
if buf.Len() > 0 {
buf.WriteByte(' ')
}
}
if t.cfg.HTTPProxy != "" {
fmt.Fprintf(&buf, "http_proxy=%q", t.cfg.HTTPProxy)
}
if t.cfg.HTTPSProxy != "" {
space()
fmt.Fprintf(&buf, "https_proxy=%q", t.cfg.HTTPSProxy)
}
if t.cfg.NoProxy != "" {
space()
fmt.Fprintf(&buf, "no_proxy=%q", t.cfg.NoProxy)
}
req := "http://example.com"
if t.req != "" {
req = t.req
}
space()
fmt.Fprintf(&buf, "req=%q", req)
return strings.TrimSpace(buf.String())
}
var proxyForURLTests = []proxyForURLTest{{
cfg: httpproxy.Config{
HTTPProxy: "127.0.0.1:8080",
},
want: "http://127.0.0.1:8080",
}, {
cfg: httpproxy.Config{
HTTPProxy: "cache.corp.example.com:1234",
},
want: "http://cache.corp.example.com:1234",
}, {
cfg: httpproxy.Config{
HTTPProxy: "cache.corp.example.com",
},
want: "http://cache.corp.example.com",
}, {
cfg: httpproxy.Config{
HTTPProxy: "https://cache.corp.example.com",
},
want: "https://cache.corp.example.com",
}, {
cfg: httpproxy.Config{
HTTPProxy: "http://127.0.0.1:8080",
},
want: "http://127.0.0.1:8080",
}, {
cfg: httpproxy.Config{
HTTPProxy: "https://127.0.0.1:8080",
},
want: "https://127.0.0.1:8080",
}, {
cfg: httpproxy.Config{
HTTPProxy: "socks5://127.0.0.1",
},
want: "socks5://127.0.0.1",
}, {
// Don't use secure for http
cfg: httpproxy.Config{
HTTPProxy: "http.proxy.tld",
HTTPSProxy: "secure.proxy.tld",
},
req: "http://insecure.tld/",
want: "http://http.proxy.tld",
}, {
// Use secure for https.
cfg: httpproxy.Config{
HTTPProxy: "http.proxy.tld",
HTTPSProxy: "secure.proxy.tld",
},
req: "https://secure.tld/",
want: "http://secure.proxy.tld",
}, {
cfg: httpproxy.Config{
HTTPProxy: "http.proxy.tld",
HTTPSProxy: "https://secure.proxy.tld",
},
req: "https://secure.tld/",
want: "https://secure.proxy.tld",
}, {
// Issue 16405: don't use HTTP_PROXY in a CGI environment,
// where HTTP_PROXY can be attacker-controlled.
cfg: httpproxy.Config{
HTTPProxy: "http://10.1.2.3:8080",
CGI: true,
},
want: "<nil>",
wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy"),
}, {
// HTTPS proxy is still used even in CGI environment.
// (perhaps dubious but it's the historical behaviour).
cfg: httpproxy.Config{
HTTPSProxy: "https://secure.proxy.tld",
CGI: true,
},
req: "https://secure.tld/",
want: "https://secure.proxy.tld",
}, {
want: "<nil>",
}, {
cfg: httpproxy.Config{
NoProxy: "example.com",
HTTPProxy: "proxy",
},
req: "http://example.com/",
want: "<nil>",
}, {
cfg: httpproxy.Config{
NoProxy: ".example.com",
HTTPProxy: "proxy",
},
req: "http://example.com/",
want: "<nil>",
}, {
cfg: httpproxy.Config{
NoProxy: "ample.com",
HTTPProxy: "proxy",
},
req: "http://example.com/",
want: "http://proxy",
}, {
cfg: httpproxy.Config{
NoProxy: "example.com",
HTTPProxy: "proxy",
},
req: "http://foo.example.com/",
want: "<nil>",
}, {
cfg: httpproxy.Config{
NoProxy: ".foo.com",
HTTPProxy: "proxy",
},
req: "http://example.com/",
want: "http://proxy",
}}
func testProxyForURL(t *testing.T, tt proxyForURLTest) {
setHelper(t)
reqURLStr := tt.req
if reqURLStr == "" {
reqURLStr = "http://example.com"
}
reqURL, err := url.Parse(reqURLStr)
if err != nil {
t.Errorf("invalid URL %q", reqURLStr)
return
}
cfg := tt.cfg
proxyForURL := cfg.ProxyFunc()
url, err := proxyForURL(reqURL)
if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
t.Errorf("%v: got error = %q, want %q", tt, g, e)
return
}
if got := fmt.Sprintf("%s", url); got != tt.want {
t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
}
// Check that changing the Config doesn't change the results
// of the functuon.
cfg = httpproxy.Config{}
url, err = proxyForURL(reqURL)
if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
t.Errorf("(after mutating config) %v: got error = %q, want %q", tt, g, e)
return
}
if got := fmt.Sprintf("%s", url); got != tt.want {
t.Errorf("(after mutating config) %v: got URL = %q, want %q", tt, url, tt.want)
}
}
func TestProxyForURL(t *testing.T) {
for _, tt := range proxyForURLTests {
testProxyForURL(t, tt)
}
}
func TestFromEnvironment(t *testing.T) {
os.Setenv("HTTP_PROXY", "httpproxy")
os.Setenv("HTTPS_PROXY", "httpsproxy")
os.Setenv("NO_PROXY", "noproxy")
os.Setenv("REQUEST_METHOD", "")
got := httpproxy.FromEnvironment()
want := httpproxy.Config{
HTTPProxy: "httpproxy",
HTTPSProxy: "httpsproxy",
NoProxy: "noproxy",
}
if *got != want {
t.Errorf("unexpected proxy config, got %#v want %#v", got, want)
}
}
func TestFromEnvironmentWithRequestMethod(t *testing.T) {
os.Setenv("HTTP_PROXY", "httpproxy")
os.Setenv("HTTPS_PROXY", "httpsproxy")
os.Setenv("NO_PROXY", "noproxy")
os.Setenv("REQUEST_METHOD", "PUT")
got := httpproxy.FromEnvironment()
want := httpproxy.Config{
HTTPProxy: "httpproxy",
HTTPSProxy: "httpsproxy",
NoProxy: "noproxy",
CGI: true,
}
if *got != want {
t.Errorf("unexpected proxy config, got %#v want %#v", got, want)
}
}
func TestFromEnvironmentLowerCase(t *testing.T) {
os.Setenv("http_proxy", "httpproxy")
os.Setenv("https_proxy", "httpsproxy")
os.Setenv("no_proxy", "noproxy")
os.Setenv("REQUEST_METHOD", "")
got := httpproxy.FromEnvironment()
want := httpproxy.Config{
HTTPProxy: "httpproxy",
HTTPSProxy: "httpsproxy",
NoProxy: "noproxy",
}
if *got != want {
t.Errorf("unexpected proxy config, got %#v want %#v", got, want)
}
}
var UseProxyTests = []struct {
host string
match bool
}{
// Never proxy localhost:
{"localhost", false},
{"127.0.0.1", false},
{"127.0.0.2", false},
{"[::1]", false},
{"[::2]", true}, // not a loopback address
{"barbaz.net", false}, // match as .barbaz.net
{"foobar.com", false}, // have a port but match
{"foofoobar.com", true}, // not match as a part of foobar.com
{"baz.com", true}, // not match as a part of barbaz.com
{"localhost.net", true}, // not match as suffix of address
{"local.localhost", true}, // not match as prefix as address
{"barbarbaz.net", true}, // not match because NO_PROXY have a '.'
{"www.foobar.com", false}, // match because NO_PROXY includes "foobar.com"
}
func TestUseProxy(t *testing.T) {
cfg := &httpproxy.Config{
NoProxy: "foobar.com, .barbaz.net",
}
for _, test := range UseProxyTests {
if httpproxy.ExportUseProxy(cfg, test.host+":80") != test.match {
t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
}
}
}
func TestInvalidNoProxy(t *testing.T) {
cfg := &httpproxy.Config{
NoProxy: ":1",
}
ok := httpproxy.ExportUseProxy(cfg, "example.com:80") // should not panic
if !ok {
t.Errorf("useProxy unexpected return; got false; want true")
}
}

View File

@ -5,7 +5,7 @@
package http2 package http2
// A list of the possible cipher suite ids. Taken from // A list of the possible cipher suite ids. Taken from
// http://www.iana.org/assignments/tls-parameters/tls-parameters.txt // https://www.iana.org/assignments/tls-parameters/tls-parameters.txt
const ( const (
cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000

View File

@ -9,7 +9,7 @@ import "testing"
func TestIsBadCipherBad(t *testing.T) { func TestIsBadCipherBad(t *testing.T) {
for _, c := range badCiphers { for _, c := range badCiphers {
if !isBadCipher(c) { if !isBadCipher(c) {
t.Errorf("Wrong result for isBadCipher(%d), want true") t.Errorf("Wrong result for isBadCipher(%d), want true", c)
} }
} }
} }

View File

@ -73,7 +73,7 @@ type noDialH2RoundTripper struct{ t *Transport }
func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
res, err := rt.t.RoundTrip(req) res, err := rt.t.RoundTrip(req)
if err == ErrNoCachedConn { if isNoCachedConnError(err) {
return nil, http.ErrSkipAltProtocol return nil, http.ErrSkipAltProtocol
} }
return res, err return res, err

View File

@ -87,13 +87,16 @@ type goAwayFlowError struct{}
func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" } func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" }
// connErrorReason wraps a ConnectionError with an informative error about why it occurs. // connError represents an HTTP/2 ConnectionError error code, along
// with a string (for debugging) explaining why.
//
// Errors of this type are only returned by the frame parser functions // Errors of this type are only returned by the frame parser functions
// and converted into ConnectionError(ErrCodeProtocol). // and converted into ConnectionError(Code), after stashing away
// the Reason into the Framer's errDetail field, accessible via
// the (*Framer).ErrorDetail method.
type connError struct { type connError struct {
Code ErrCode Code ErrCode // the ConnectionError error code
Reason string Reason string // additional reason
} }
func (e connError) Error() string { func (e connError) Error() string {

View File

@ -52,3 +52,5 @@ func reqGetBody(req *http.Request) func() (io.ReadCloser, error) {
func reqBodyIsNoBody(body io.ReadCloser) bool { func reqBodyIsNoBody(body io.ReadCloser) bool {
return body == http.NoBody return body == http.NoBody
} }
func go18httpNoBody() io.ReadCloser { return http.NoBody } // for tests only

View File

@ -46,7 +46,6 @@ func TestServerGracefulShutdown(t *testing.T) {
wanth := [][2]string{ wanth := [][2]string{
{":status", "200"}, {":status", "200"},
{"x-foo", "bar"}, {"x-foo", "bar"},
{"content-type", "text/plain; charset=utf-8"},
{"content-length", "0"}, {"content-length", "0"},
} }
if !reflect.DeepEqual(goth, wanth) { if !reflect.DeepEqual(goth, wanth) {

View File

@ -3,3 +3,4 @@ h2demo.linux
client-id.dat client-id.dat
client-secret.dat client-secret.dat
token.dat token.dat
ca-certificates.crt

11
vendor/golang.org/x/net/http2/h2demo/Dockerfile generated vendored Normal file
View File

@ -0,0 +1,11 @@
# Copyright 2018 The Go Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
FROM scratch
LABEL maintainer "golang-dev@googlegroups.com"
COPY ca-certificates.crt /etc/ssl/certs/
COPY h2demo /
ENTRYPOINT ["/h2demo", "-prod"]

134
vendor/golang.org/x/net/http2/h2demo/Dockerfile.0 generated vendored Normal file
View File

@ -0,0 +1,134 @@
# Copyright 2018 The Go Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
FROM golang:1.9
LABEL maintainer "golang-dev@googlegroups.com"
ENV CGO_ENABLED=0
# BEGIN deps (run `make update-deps` to update)
# Repo cloud.google.com/go at 1d0c2da (2018-01-30)
ENV REV=1d0c2da40456a9b47f5376165f275424acc15c09
RUN go get -d cloud.google.com/go/compute/metadata `#and 6 other pkgs` &&\
(cd /go/src/cloud.google.com/go && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Repo github.com/golang/protobuf at 9255415 (2018-01-25)
ENV REV=925541529c1fa6821df4e44ce2723319eb2be768
RUN go get -d github.com/golang/protobuf/proto `#and 6 other pkgs` &&\
(cd /go/src/github.com/golang/protobuf && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Repo github.com/googleapis/gax-go at 317e000 (2017-09-15)
ENV REV=317e0006254c44a0ac427cc52a0e083ff0b9622f
RUN go get -d github.com/googleapis/gax-go &&\
(cd /go/src/github.com/googleapis/gax-go && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Repo go4.org at 034d17a (2017-05-25)
ENV REV=034d17a462f7b2dcd1a4a73553ec5357ff6e6c6e
RUN go get -d go4.org/syncutil/singleflight &&\
(cd /go/src/go4.org && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Repo golang.org/x/build at 8aa9ee0 (2018-02-01)
ENV REV=8aa9ee0e557fd49c14113e5ba106e13a5b455460
RUN go get -d golang.org/x/build/autocertcache &&\
(cd /go/src/golang.org/x/build && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Repo golang.org/x/crypto at 1875d0a (2018-01-27)
ENV REV=1875d0a70c90e57f11972aefd42276df65e895b9
RUN go get -d golang.org/x/crypto/acme `#and 2 other pkgs` &&\
(cd /go/src/golang.org/x/crypto && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Repo golang.org/x/oauth2 at 30785a2 (2018-01-04)
ENV REV=30785a2c434e431ef7c507b54617d6a951d5f2b4
RUN go get -d golang.org/x/oauth2 `#and 5 other pkgs` &&\
(cd /go/src/golang.org/x/oauth2 && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Repo golang.org/x/text at e19ae14 (2017-12-27)
ENV REV=e19ae1496984b1c655b8044a65c0300a3c878dd3
RUN go get -d golang.org/x/text/secure/bidirule `#and 4 other pkgs` &&\
(cd /go/src/golang.org/x/text && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Repo google.golang.org/api at 7d0e2d3 (2018-01-30)
ENV REV=7d0e2d350555821bef5a5b8aecf0d12cc1def633
RUN go get -d google.golang.org/api/gensupport `#and 9 other pkgs` &&\
(cd /go/src/google.golang.org/api && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Repo google.golang.org/genproto at 4eb30f4 (2018-01-25)
ENV REV=4eb30f4778eed4c258ba66527a0d4f9ec8a36c45
RUN go get -d google.golang.org/genproto/googleapis/api/annotations `#and 3 other pkgs` &&\
(cd /go/src/google.golang.org/genproto && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Repo google.golang.org/grpc at 0bd008f (2018-01-25)
ENV REV=0bd008f5fadb62d228f12b18d016709e8139a7af
RUN go get -d google.golang.org/grpc `#and 23 other pkgs` &&\
(cd /go/src/google.golang.org/grpc && (git cat-file -t $REV 2>/dev/null || git fetch -q origin $REV) && git reset --hard $REV)
# Optimization to speed up iterative development, not necessary for correctness:
RUN go install cloud.google.com/go/compute/metadata \
cloud.google.com/go/iam \
cloud.google.com/go/internal \
cloud.google.com/go/internal/optional \
cloud.google.com/go/internal/version \
cloud.google.com/go/storage \
github.com/golang/protobuf/proto \
github.com/golang/protobuf/protoc-gen-go/descriptor \
github.com/golang/protobuf/ptypes \
github.com/golang/protobuf/ptypes/any \
github.com/golang/protobuf/ptypes/duration \
github.com/golang/protobuf/ptypes/timestamp \
github.com/googleapis/gax-go \
go4.org/syncutil/singleflight \
golang.org/x/build/autocertcache \
golang.org/x/crypto/acme \
golang.org/x/crypto/acme/autocert \
golang.org/x/oauth2 \
golang.org/x/oauth2/google \
golang.org/x/oauth2/internal \
golang.org/x/oauth2/jws \
golang.org/x/oauth2/jwt \
golang.org/x/text/secure/bidirule \
golang.org/x/text/transform \
golang.org/x/text/unicode/bidi \
golang.org/x/text/unicode/norm \
google.golang.org/api/gensupport \
google.golang.org/api/googleapi \
google.golang.org/api/googleapi/internal/uritemplates \
google.golang.org/api/googleapi/transport \
google.golang.org/api/internal \
google.golang.org/api/iterator \
google.golang.org/api/option \
google.golang.org/api/storage/v1 \
google.golang.org/api/transport/http \
google.golang.org/genproto/googleapis/api/annotations \
google.golang.org/genproto/googleapis/iam/v1 \
google.golang.org/genproto/googleapis/rpc/status \
google.golang.org/grpc \
google.golang.org/grpc/balancer \
google.golang.org/grpc/balancer/base \
google.golang.org/grpc/balancer/roundrobin \
google.golang.org/grpc/codes \
google.golang.org/grpc/connectivity \
google.golang.org/grpc/credentials \
google.golang.org/grpc/encoding \
google.golang.org/grpc/encoding/proto \
google.golang.org/grpc/grpclb/grpc_lb_v1/messages \
google.golang.org/grpc/grpclog \
google.golang.org/grpc/internal \
google.golang.org/grpc/keepalive \
google.golang.org/grpc/metadata \
google.golang.org/grpc/naming \
google.golang.org/grpc/peer \
google.golang.org/grpc/resolver \
google.golang.org/grpc/resolver/dns \
google.golang.org/grpc/resolver/passthrough \
google.golang.org/grpc/stats \
google.golang.org/grpc/status \
google.golang.org/grpc/tap \
google.golang.org/grpc/transport
# END deps
COPY . /go/src/golang.org/x/net/
RUN go install -tags "h2demo netgo" -ldflags "-linkmode=external -extldflags '-static -pthread'" golang.org/x/net/http2/h2demo

View File

@ -1,8 +1,55 @@
h2demo.linux: h2demo.go # Copyright 2018 The Go Authors. All rights reserved.
GOOS=linux go build --tags=h2demo -o h2demo.linux . # Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
MUTABLE_VERSION ?= latest
VERSION ?= $(shell git rev-parse --short HEAD)
IMAGE_STAGING := gcr.io/go-dashboard-dev/h2demo
IMAGE_PROD := gcr.io/symbolic-datum-552/h2demo
DOCKER_IMAGE_build0=build0/h2demo:latest
DOCKER_CTR_build0=h2demo-build0
build0: *.go Dockerfile.0
docker build --force-rm -f Dockerfile.0 --tag=$(DOCKER_IMAGE_build0) ../..
h2demo: build0
docker create --name $(DOCKER_CTR_build0) $(DOCKER_IMAGE_build0)
docker cp $(DOCKER_CTR_build0):/go/bin/$@ $@
docker rm $(DOCKER_CTR_build0)
ca-certificates.crt:
docker create --name $(DOCKER_CTR_build0) $(DOCKER_IMAGE_build0)
docker cp $(DOCKER_CTR_build0):/etc/ssl/certs/$@ $@
docker rm $(DOCKER_CTR_build0)
update-deps:
go install golang.org/x/build/cmd/gitlock
gitlock --update=Dockerfile.0 --ignore=golang.org/x/net --tags=h2demo golang.org/x/net/http2/h2demo
docker-prod: Dockerfile h2demo ca-certificates.crt
docker build --force-rm --tag=$(IMAGE_PROD):$(VERSION) .
docker tag $(IMAGE_PROD):$(VERSION) $(IMAGE_PROD):$(MUTABLE_VERSION)
docker-staging: Dockerfile h2demo ca-certificates.crt
docker build --force-rm --tag=$(IMAGE_STAGING):$(VERSION) .
docker tag $(IMAGE_STAGING):$(VERSION) $(IMAGE_STAGING):$(MUTABLE_VERSION)
push-prod: docker-prod
gcloud docker -- push $(IMAGE_PROD):$(MUTABLE_VERSION)
gcloud docker -- push $(IMAGE_PROD):$(VERSION)
push-staging: docker-staging
gcloud docker -- push $(IMAGE_STAGING):$(MUTABLE_VERSION)
gcloud docker -- push $(IMAGE_STAGING):$(VERSION)
deploy-prod: push-prod
kubectl set image deployment/h2demo-deployment h2demo=$(IMAGE_PROD):$(VERSION)
deploy-staging: push-staging
kubectl set image deployment/h2demo-deployment h2demo=$(IMAGE_STAGING):$(VERSION)
.PHONY: clean
clean:
$(RM) h2demo
$(RM) ca-certificates.crt
FORCE: FORCE:
upload: FORCE
go install golang.org/x/build/cmd/upload
upload --verbose --osarch=linux-amd64 --tags=h2demo --file=go:golang.org/x/net/http2/h2demo --public http2-demo-server-tls/h2demo

View File

@ -0,0 +1,28 @@
apiVersion: extensions/v1beta1
kind: Deployment
metadata:
name: h2demo-deployment
spec:
replicas: 1
template:
metadata:
labels:
app: h2demo
annotations:
container.seccomp.security.alpha.kubernetes.io/h2demo: docker/default
container.apparmor.security.beta.kubernetes.io/h2demo: runtime/default
spec:
containers:
- name: h2demo
image: gcr.io/symbolic-datum-552/h2demo:latest
imagePullPolicy: Always
command: ["/h2demo", "-prod"]
ports:
- containerPort: 80
- containerPort: 443
resources:
requests:
cpu: "1"
memory: "1Gi"
limits:
memory: "2Gi"

View File

@ -8,6 +8,7 @@ package main
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"flag" "flag"
"fmt" "fmt"
@ -19,7 +20,6 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"os"
"path" "path"
"regexp" "regexp"
"runtime" "runtime"
@ -28,7 +28,9 @@ import (
"sync" "sync"
"time" "time"
"cloud.google.com/go/storage"
"go4.org/syncutil/singleflight" "go4.org/syncutil/singleflight"
"golang.org/x/build/autocertcache"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
@ -426,19 +428,10 @@ func httpHost() string {
} }
} }
func serveProdTLS() error { func serveProdTLS(autocertManager *autocert.Manager) error {
const cacheDir = "/var/cache/autocert"
if err := os.MkdirAll(cacheDir, 0700); err != nil {
return err
}
m := autocert.Manager{
Cache: autocert.DirCache(cacheDir),
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist("http2.golang.org"),
}
srv := &http.Server{ srv := &http.Server{
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
GetCertificate: m.GetCertificate, GetCertificate: autocertManager.GetCertificate,
}, },
} }
http2.ConfigureServer(srv, &http2.Server{ http2.ConfigureServer(srv, &http2.Server{
@ -468,9 +461,21 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
} }
func serveProd() error { func serveProd() error {
log.Printf("running in production mode")
storageClient, err := storage.NewClient(context.Background())
if err != nil {
log.Fatalf("storage.NewClient: %v", err)
}
autocertManager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist("http2.golang.org"),
Cache: autocertcache.NewGoogleCloudStorageCache(storageClient, "golang-h2demo-autocert"),
}
errc := make(chan error, 2) errc := make(chan error, 2)
go func() { errc <- http.ListenAndServe(":80", nil) }() go func() { errc <- http.ListenAndServe(":80", autocertManager.HTTPHandler(http.DefaultServeMux)) }()
go func() { errc <- serveProdTLS() }() go func() { errc <- serveProdTLS(autocertManager) }()
return <-errc return <-errc
} }

17
vendor/golang.org/x/net/http2/h2demo/service.yaml generated vendored Normal file
View File

@ -0,0 +1,17 @@
apiVersion: v1
kind: Service
metadata:
name: h2demo
spec:
externalTrafficPolicy: Local
ports:
- port: 80
targetPort: 80
name: http
- port: 443
targetPort: 443
name: https
selector:
app: h2demo
type: LoadBalancer
loadBalancerIP: 130.211.116.44

View File

@ -45,6 +45,7 @@ var (
flagNextProto = flag.String("nextproto", "h2,h2-14", "Comma-separated list of NPN/ALPN protocol names to negotiate.") flagNextProto = flag.String("nextproto", "h2,h2-14", "Comma-separated list of NPN/ALPN protocol names to negotiate.")
flagInsecure = flag.Bool("insecure", false, "Whether to skip TLS cert validation") flagInsecure = flag.Bool("insecure", false, "Whether to skip TLS cert validation")
flagSettings = flag.String("settings", "empty", "comma-separated list of KEY=value settings for the initial SETTINGS frame. The magic value 'empty' sends an empty initial settings frame, and the magic value 'omit' causes no initial settings frame to be sent.") flagSettings = flag.String("settings", "empty", "comma-separated list of KEY=value settings for the initial SETTINGS frame. The magic value 'empty' sends an empty initial settings frame, and the magic value 'omit' causes no initial settings frame to be sent.")
flagDial = flag.String("dial", "", "optional ip:port to dial, to connect to a host:port but use a different SNI name (including a SNI name without DNS)")
) )
type command struct { type command struct {
@ -147,11 +148,14 @@ func (app *h2i) Main() error {
InsecureSkipVerify: *flagInsecure, InsecureSkipVerify: *flagInsecure,
} }
hostAndPort := withPort(app.host) hostAndPort := *flagDial
if hostAndPort == "" {
hostAndPort = withPort(app.host)
}
log.Printf("Connecting to %s ...", hostAndPort) log.Printf("Connecting to %s ...", hostAndPort)
tc, err := tls.Dial("tcp", hostAndPort, cfg) tc, err := tls.Dial("tcp", hostAndPort, cfg)
if err != nil { if err != nil {
return fmt.Errorf("Error dialing %s: %v", withPort(app.host), err) return fmt.Errorf("Error dialing %s: %v", hostAndPort, err)
} }
log.Printf("Connected to %v", tc.RemoteAddr()) log.Printf("Connected to %v", tc.RemoteAddr())
defer tc.Close() defer tc.Close()
@ -460,6 +464,15 @@ func (app *h2i) readFrames() error {
app.hdec = hpack.NewDecoder(tableSize, app.onNewHeaderField) app.hdec = hpack.NewDecoder(tableSize, app.onNewHeaderField)
} }
app.hdec.Write(f.HeaderBlockFragment()) app.hdec.Write(f.HeaderBlockFragment())
case *http2.PushPromiseFrame:
if app.hdec == nil {
// TODO: if the user uses h2i to send a SETTINGS frame advertising
// something larger, we'll need to respect SETTINGS_HEADER_TABLE_SIZE
// and stuff here instead of using the 4k default. But for now:
tableSize := uint32(4 << 10)
app.hdec = hpack.NewDecoder(tableSize, app.onNewHeaderField)
}
app.hdec.Write(f.HeaderBlockFragment())
} }
} }
} }

View File

@ -206,7 +206,7 @@ func appendVarInt(dst []byte, n byte, i uint64) []byte {
} }
// appendHpackString appends s, as encoded in "String Literal" // appendHpackString appends s, as encoded in "String Literal"
// representation, to dst and returns the the extended buffer. // representation, to dst and returns the extended buffer.
// //
// s will be encoded in Huffman codes only when it produces strictly // s will be encoded in Huffman codes only when it produces strictly
// shorter byte string. // shorter byte string.

View File

@ -312,7 +312,7 @@ func mustUint31(v int32) uint32 {
} }
// bodyAllowedForStatus reports whether a given response status code // bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4. // permits a body. See RFC 7230, section 3.3.
func bodyAllowedForStatus(status int) bool { func bodyAllowedForStatus(status int) bool {
switch { switch {
case status >= 100 && status <= 199: case status >= 100 && status <= 199:
@ -376,12 +376,16 @@ func (s *sorter) SortStrings(ss []string) {
// validPseudoPath reports whether v is a valid :path pseudo-header // validPseudoPath reports whether v is a valid :path pseudo-header
// value. It must be either: // value. It must be either:
// //
// *) a non-empty string starting with '/', but not with with "//", // *) a non-empty string starting with '/'
// *) the string '*', for OPTIONS requests. // *) the string '*', for OPTIONS requests.
// //
// For now this is only used a quick check for deciding when to clean // For now this is only used a quick check for deciding when to clean
// up Opaque URLs before sending requests from the Transport. // up Opaque URLs before sending requests from the Transport.
// See golang.org/issue/16847 // See golang.org/issue/16847
//
// We used to enforce that the path also didn't start with "//", but
// Google's GFE accepts such paths and Chrome sends them, so ignore
// that part of the spec. See golang.org/issue/19103.
func validPseudoPath(v string) bool { func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" return (len(v) > 0 && v[0] == '/') || v == "*"
} }

View File

@ -25,3 +25,5 @@ func reqGetBody(req *http.Request) func() (io.ReadCloser, error) {
} }
func reqBodyIsNoBody(io.ReadCloser) bool { return false } func reqBodyIsNoBody(io.ReadCloser) bool { return false }
func go18httpNoBody() io.ReadCloser { return nil } // for tests only

View File

@ -46,6 +46,7 @@ import (
"sync" "sync"
"time" "time"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )
@ -220,12 +221,15 @@ func ConfigureServer(s *http.Server, conf *Server) error {
} else if s.TLSConfig.CipherSuites != nil { } else if s.TLSConfig.CipherSuites != nil {
// If they already provided a CipherSuite list, return // If they already provided a CipherSuite list, return
// an error if it has a bad order or is missing // an error if it has a bad order or is missing
// ECDHE_RSA_WITH_AES_128_GCM_SHA256. // ECDHE_RSA_WITH_AES_128_GCM_SHA256 or ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.
const requiredCipher = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
haveRequired := false haveRequired := false
sawBad := false sawBad := false
for i, cs := range s.TLSConfig.CipherSuites { for i, cs := range s.TLSConfig.CipherSuites {
if cs == requiredCipher { switch cs {
case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
// Alternative MTI cipher to not discourage ECDSA-only servers.
// See http://golang.org/cl/30721 for further information.
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
haveRequired = true haveRequired = true
} }
if isBadCipher(cs) { if isBadCipher(cs) {
@ -235,7 +239,7 @@ func ConfigureServer(s *http.Server, conf *Server) error {
} }
} }
if !haveRequired { if !haveRequired {
return fmt.Errorf("http2: TLSConfig.CipherSuites is missing HTTP/2-required TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256") return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher.")
} }
} }
@ -403,7 +407,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// addresses during development. // addresses during development.
// //
// TODO: optionally enforce? Or enforce at the time we receive // TODO: optionally enforce? Or enforce at the time we receive
// a new request, and verify the the ServerName matches the :authority? // a new request, and verify the ServerName matches the :authority?
// But that precludes proxy situations, perhaps. // But that precludes proxy situations, perhaps.
// //
// So for now, do nothing here again. // So for now, do nothing here again.
@ -649,7 +653,7 @@ func (sc *serverConn) condlogf(err error, format string, args ...interface{}) {
if err == nil { if err == nil {
return return
} }
if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) { if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) || err == errPrefaceTimeout {
// Boring, expected errors. // Boring, expected errors.
sc.vlogf(format, args...) sc.vlogf(format, args...)
} else { } else {
@ -853,8 +857,13 @@ func (sc *serverConn) serve() {
} }
} }
if sc.inGoAway && sc.curOpenStreams() == 0 && !sc.needToSendGoAway && !sc.writingFrame { // Start the shutdown timer after sending a GOAWAY. When sending GOAWAY
return // with no error code (graceful shutdown), don't start the timer until
// all open streams have been completed.
sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame
gracefulShutdownComplete := sc.goAwayCode == ErrCodeNo && sc.curOpenStreams() == 0
if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != ErrCodeNo || gracefulShutdownComplete) {
sc.shutDownIn(goAwayTimeout)
} }
} }
} }
@ -889,8 +898,11 @@ func (sc *serverConn) sendServeMsg(msg interface{}) {
} }
} }
// readPreface reads the ClientPreface greeting from the peer var errPrefaceTimeout = errors.New("timeout waiting for client preface")
// or returns an error on timeout or an invalid greeting.
// readPreface reads the ClientPreface greeting from the peer or
// returns errPrefaceTimeout on timeout, or an error if the greeting
// is invalid.
func (sc *serverConn) readPreface() error { func (sc *serverConn) readPreface() error {
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
@ -908,7 +920,7 @@ func (sc *serverConn) readPreface() error {
defer timer.Stop() defer timer.Stop()
select { select {
case <-timer.C: case <-timer.C:
return errors.New("timeout waiting for client preface") return errPrefaceTimeout
case err := <-errc: case err := <-errc:
if err == nil { if err == nil {
if VerboseLogs { if VerboseLogs {
@ -1218,30 +1230,31 @@ func (sc *serverConn) startGracefulShutdown() {
sc.shutdownOnce.Do(func() { sc.sendServeMsg(gracefulShutdownMsg) }) sc.shutdownOnce.Do(func() { sc.sendServeMsg(gracefulShutdownMsg) })
} }
// After sending GOAWAY, the connection will close after goAwayTimeout.
// If we close the connection immediately after sending GOAWAY, there may
// be unsent data in our kernel receive buffer, which will cause the kernel
// to send a TCP RST on close() instead of a FIN. This RST will abort the
// connection immediately, whether or not the client had received the GOAWAY.
//
// Ideally we should delay for at least 1 RTT + epsilon so the client has
// a chance to read the GOAWAY and stop sending messages. Measuring RTT
// is hard, so we approximate with 1 second. See golang.org/issue/18701.
//
// This is a var so it can be shorter in tests, where all requests uses the
// loopback interface making the expected RTT very small.
//
// TODO: configurable?
var goAwayTimeout = 1 * time.Second
func (sc *serverConn) startGracefulShutdownInternal() { func (sc *serverConn) startGracefulShutdownInternal() {
sc.goAwayIn(ErrCodeNo, 0) sc.goAway(ErrCodeNo)
} }
func (sc *serverConn) goAway(code ErrCode) { func (sc *serverConn) goAway(code ErrCode) {
sc.serveG.check()
var forceCloseIn time.Duration
if code != ErrCodeNo {
forceCloseIn = 250 * time.Millisecond
} else {
// TODO: configurable
forceCloseIn = 1 * time.Second
}
sc.goAwayIn(code, forceCloseIn)
}
func (sc *serverConn) goAwayIn(code ErrCode, forceCloseIn time.Duration) {
sc.serveG.check() sc.serveG.check()
if sc.inGoAway { if sc.inGoAway {
return return
} }
if forceCloseIn != 0 {
sc.shutDownIn(forceCloseIn)
}
sc.inGoAway = true sc.inGoAway = true
sc.needToSendGoAway = true sc.needToSendGoAway = true
sc.goAwayCode = code sc.goAwayCode = code
@ -1805,7 +1818,7 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
if st.trailer != nil { if st.trailer != nil {
for _, hf := range f.RegularFields() { for _, hf := range f.RegularFields() {
key := sc.canonicalHeader(hf.Name) key := sc.canonicalHeader(hf.Name)
if !ValidTrailerHeader(key) { if !httpguts.ValidTrailerHeader(key) {
// TODO: send more details to the peer somehow. But http2 has // TODO: send more details to the peer somehow. But http2 has
// no way to send debug data at a stream level. Discuss with // no way to send debug data at a stream level. Discuss with
// HTTP folk. // HTTP folk.
@ -2252,6 +2265,7 @@ type responseWriterState struct {
wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
sentHeader bool // have we sent the header frame? sentHeader bool // have we sent the header frame?
handlerDone bool // handler has finished handlerDone bool // handler has finished
dirty bool // a Write failed; don't reuse this responseWriterState
sentContentLen int64 // non-zero if handler set a Content-Length header sentContentLen int64 // non-zero if handler set a Content-Length header
wroteBytes int64 wroteBytes int64
@ -2271,8 +2285,8 @@ func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) !=
// written in the trailers at the end of the response. // written in the trailers at the end of the response.
func (rws *responseWriterState) declareTrailer(k string) { func (rws *responseWriterState) declareTrailer(k string) {
k = http.CanonicalHeaderKey(k) k = http.CanonicalHeaderKey(k)
if !ValidTrailerHeader(k) { if !httpguts.ValidTrailerHeader(k) {
// Forbidden by RFC 2616 14.40. // Forbidden by RFC 7230, section 4.1.2.
rws.conn.logf("ignoring invalid trailer %q", k) rws.conn.logf("ignoring invalid trailer %q", k)
return return
} }
@ -2309,8 +2323,16 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
clen = strconv.Itoa(len(p)) clen = strconv.Itoa(len(p))
} }
_, hasContentType := rws.snapHeader["Content-Type"] _, hasContentType := rws.snapHeader["Content-Type"]
if !hasContentType && bodyAllowedForStatus(rws.status) { if !hasContentType && bodyAllowedForStatus(rws.status) && len(p) > 0 {
ctype = http.DetectContentType(p) if cto := rws.snapHeader.Get("X-Content-Type-Options"); strings.EqualFold("nosniff", cto) {
// nosniff is an explicit directive not to guess a content-type.
// Content-sniffing is no less susceptible to polyglot attacks via
// hosted content when done on the server.
ctype = "application/octet-stream"
rws.conn.logf("http2: WriteHeader called with X-Content-Type-Options:nosniff but no Content-Type")
} else {
ctype = http.DetectContentType(p)
}
} }
var date string var date string
if _, ok := rws.snapHeader["Date"]; !ok { if _, ok := rws.snapHeader["Date"]; !ok {
@ -2333,6 +2355,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
date: date, date: date,
}) })
if err != nil { if err != nil {
rws.dirty = true
return 0, err return 0, err
} }
if endStream { if endStream {
@ -2354,6 +2377,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
if len(p) > 0 || endStream { if len(p) > 0 || endStream {
// only send a 0 byte DATA frame if we're ending the stream. // only send a 0 byte DATA frame if we're ending the stream.
if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil { if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
rws.dirty = true
return 0, err return 0, err
} }
} }
@ -2365,6 +2389,9 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
trailers: rws.trailers, trailers: rws.trailers,
endStream: true, endStream: true,
}) })
if err != nil {
rws.dirty = true
}
return len(p), err return len(p), err
} }
return len(p), nil return len(p), nil
@ -2388,7 +2415,7 @@ const TrailerPrefix = "Trailer:"
// after the header has already been flushed. Because the Go // after the header has already been flushed. Because the Go
// ResponseWriter interface has no way to set Trailers (only the // ResponseWriter interface has no way to set Trailers (only the
// Header), and because we didn't want to expand the ResponseWriter // Header), and because we didn't want to expand the ResponseWriter
// interface, and because nobody used trailers, and because RFC 2616 // interface, and because nobody used trailers, and because RFC 7230
// says you SHOULD (but not must) predeclare any trailers in the // says you SHOULD (but not must) predeclare any trailers in the
// header, the official ResponseWriter rules said trailers in Go must // header, the official ResponseWriter rules said trailers in Go must
// be predeclared, and then we reuse the same ResponseWriter.Header() // be predeclared, and then we reuse the same ResponseWriter.Header()
@ -2472,6 +2499,24 @@ func (w *responseWriter) Header() http.Header {
return rws.handlerHeader return rws.handlerHeader
} }
// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode.
func checkWriteHeaderCode(code int) {
// Issue 22880: require valid WriteHeader status codes.
// For now we only enforce that it's three digits.
// In the future we might block things over 599 (600 and above aren't defined
// at http://httpwg.org/specs/rfc7231.html#status.codes)
// and we might block under 200 (once we have more mature 1xx support).
// But for now any three digits.
//
// We used to send "HTTP/1.1 000 0" on the wire in responses but there's
// no equivalent bogus thing we can realistically send in HTTP/2,
// so we'll consistently panic instead and help people find their bugs
// early. (We can't return an error from WriteHeader even if we wanted to.)
if code < 100 || code > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
}
func (w *responseWriter) WriteHeader(code int) { func (w *responseWriter) WriteHeader(code int) {
rws := w.rws rws := w.rws
if rws == nil { if rws == nil {
@ -2482,6 +2527,7 @@ func (w *responseWriter) WriteHeader(code int) {
func (rws *responseWriterState) writeHeader(code int) { func (rws *responseWriterState) writeHeader(code int) {
if !rws.wroteHeader { if !rws.wroteHeader {
checkWriteHeaderCode(code)
rws.wroteHeader = true rws.wroteHeader = true
rws.status = code rws.status = code
if len(rws.handlerHeader) > 0 { if len(rws.handlerHeader) > 0 {
@ -2504,7 +2550,7 @@ func cloneHeader(h http.Header) http.Header {
// //
// * Handler calls w.Write or w.WriteString -> // * Handler calls w.Write or w.WriteString ->
// * -> rws.bw (*bufio.Writer) -> // * -> rws.bw (*bufio.Writer) ->
// * (Handler migth call Flush) // * (Handler might call Flush)
// * -> chunkWriter{rws} // * -> chunkWriter{rws}
// * -> responseWriterState.writeChunk(p []byte) // * -> responseWriterState.writeChunk(p []byte)
// * -> responseWriterState.writeChunk (most of the magic; see comment there) // * -> responseWriterState.writeChunk (most of the magic; see comment there)
@ -2543,10 +2589,19 @@ func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int,
func (w *responseWriter) handlerDone() { func (w *responseWriter) handlerDone() {
rws := w.rws rws := w.rws
dirty := rws.dirty
rws.handlerDone = true rws.handlerDone = true
w.Flush() w.Flush()
w.rws = nil w.rws = nil
responseWriterStatePool.Put(rws) if !dirty {
// Only recycle the pool if all prior Write calls to
// the serverConn goroutine completed successfully. If
// they returned earlier due to resets from the peer
// there might still be write goroutines outstanding
// from the serverConn referencing the rws memory. See
// issue 20704.
responseWriterStatePool.Put(rws)
}
} }
// Push errors. // Push errors.
@ -2744,7 +2799,7 @@ func (sc *serverConn) startPush(msg *startPushRequest) {
} }
// foreachHeaderElement splits v according to the "#rule" construction // foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element. // in RFC 7230 section 7 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) { func foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v) v = textproto.TrimString(v)
if v == "" { if v == "" {
@ -2792,41 +2847,6 @@ func new400Handler(err error) http.HandlerFunc {
} }
} }
// ValidTrailerHeader reports whether name is a valid header field name to appear
// in trailers.
// See: http://tools.ietf.org/html/rfc7230#section-4.1.2
func ValidTrailerHeader(name string) bool {
name = http.CanonicalHeaderKey(name)
if strings.HasPrefix(name, "If-") || badTrailer[name] {
return false
}
return true
}
var badTrailer = map[string]bool{
"Authorization": true,
"Cache-Control": true,
"Connection": true,
"Content-Encoding": true,
"Content-Length": true,
"Content-Range": true,
"Content-Type": true,
"Expect": true,
"Host": true,
"Keep-Alive": true,
"Max-Forwards": true,
"Pragma": true,
"Proxy-Authenticate": true,
"Proxy-Authorization": true,
"Proxy-Connection": true,
"Range": true,
"Realm": true,
"Te": true,
"Trailer": true,
"Transfer-Encoding": true,
"Www-Authenticate": true,
}
// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives // h1ServerKeepAlivesDisabled reports whether hs has its keep-alives
// disabled. See comments on h1ServerShutdownChan above for why // disabled. See comments on h1ServerShutdownChan above for why
// the code is written this way. // the code is written this way.

View File

@ -68,6 +68,7 @@ type serverTester struct {
func init() { func init() {
testHookOnPanicMu = new(sync.Mutex) testHookOnPanicMu = new(sync.Mutex)
goAwayTimeout = 25 * time.Millisecond
} }
func resetHooks() { func resetHooks() {
@ -286,7 +287,7 @@ func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error
case *WindowUpdateFrame: case *WindowUpdateFrame:
if f.FrameHeader.StreamID != 0 { if f.FrameHeader.StreamID != 0 {
st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID, 0) st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
} }
incr := uint32((&Server{}).initialConnRecvWindowSize() - initialWindowSize) incr := uint32((&Server{}).initialConnRecvWindowSize() - initialWindowSize)
if f.Increment != incr { if f.Increment != incr {
@ -1717,7 +1718,6 @@ func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
wanth := [][2]string{ wanth := [][2]string{
{":status", "200"}, {":status", "200"},
{"foo-bar", "some-value"}, {"foo-bar", "some-value"},
{"content-type", "text/plain; charset=utf-8"},
{"content-length", "0"}, {"content-length", "0"},
} }
if !reflect.DeepEqual(goth, wanth) { if !reflect.DeepEqual(goth, wanth) {
@ -1760,6 +1760,42 @@ func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
}) })
} }
func TestServer_Response_Nosniff_WithoutContentType(t *testing.T) {
const msg = "<html>this is HTML."
testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(200)
io.WriteString(w, msg)
return nil
}, func(st *serverTester) {
getSlash(st)
hf := st.wantHeaders()
if hf.StreamEnded() {
t.Fatal("don't want END_STREAM, expecting data")
}
if !hf.HeadersEnded() {
t.Fatal("want END_HEADERS flag")
}
goth := st.decodeHeader(hf.HeaderBlockFragment())
wanth := [][2]string{
{":status", "200"},
{"x-content-type-options", "nosniff"},
{"content-type", "application/octet-stream"},
{"content-length", strconv.Itoa(len(msg))},
}
if !reflect.DeepEqual(goth, wanth) {
t.Errorf("Got headers %v; want %v", goth, wanth)
}
df := st.wantData()
if !df.StreamEnded() {
t.Error("expected DATA to have END_STREAM flag")
}
if got := string(df.Data()); got != msg {
t.Errorf("got DATA %q; want %q", got, msg)
}
})
}
func TestServer_Response_TransferEncoding_chunked(t *testing.T) { func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
const msg = "hi" const msg = "hi"
testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
@ -2877,9 +2913,9 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) {
w.Header().Set("Trailer:post-header-trailer2", "hi2") w.Header().Set("Trailer:post-header-trailer2", "hi2")
w.Header().Set("Trailer:Range", "invalid") w.Header().Set("Trailer:Range", "invalid")
w.Header().Set("Trailer:Foo\x01Bogus", "invalid") w.Header().Set("Trailer:Foo\x01Bogus", "invalid")
w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 2616 14.40") w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 7230 4.1.2")
w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 2616 14.40") w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 7230 4.1.2")
w.Header().Set("Trailer", "should not be included; Forbidden by RFC 2616 14.40") w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2")
return nil return nil
}, func(st *serverTester) { }, func(st *serverTester) {
getSlash(st) getSlash(st)
@ -2952,7 +2988,6 @@ func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
wanth := [][2]string{ wanth := [][2]string{
{":status", "200"}, {":status", "200"},
{"ok1", "x"}, {"ok1", "x"},
{"content-type", "text/plain; charset=utf-8"},
{"content-length", "0"}, {"content-length", "0"},
} }
if !reflect.DeepEqual(goth, wanth) { if !reflect.DeepEqual(goth, wanth) {
@ -2972,7 +3007,7 @@ func BenchmarkServerGets(b *testing.B) {
defer st.Close() defer st.Close()
st.greet() st.greet()
// Give the server quota to reply. (plus it has the the 64KB) // Give the server quota to reply. (plus it has the 64KB)
if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil { if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -3010,7 +3045,7 @@ func BenchmarkServerPosts(b *testing.B) {
defer st.Close() defer st.Close()
st.greet() st.greet()
// Give the server quota to reply. (plus it has the the 64KB) // Give the server quota to reply. (plus it has the 64KB)
if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil { if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -3188,12 +3223,18 @@ func TestConfigureServer(t *testing.T) {
CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
}, },
}, },
{
name: "just the alternative required cipher suite",
tlsConfig: &tls.Config{
CipherSuites: []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
},
},
{ {
name: "missing required cipher suite", name: "missing required cipher suite",
tlsConfig: &tls.Config{ tlsConfig: &tls.Config{
CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384}, CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384},
}, },
wantErr: "is missing HTTP/2-required TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", wantErr: "is missing an HTTP/2-required AES_128_GCM_SHA256 cipher.",
}, },
{ {
name: "required after bad", name: "required after bad",
@ -3259,7 +3300,6 @@ func TestServerNoAutoContentLengthOnHead(t *testing.T) {
headers := st.decodeHeader(h.HeaderBlockFragment()) headers := st.decodeHeader(h.HeaderBlockFragment())
want := [][2]string{ want := [][2]string{
{":status", "200"}, {":status", "200"},
{"content-type", "text/plain; charset=utf-8"},
} }
if !reflect.DeepEqual(headers, want) { if !reflect.DeepEqual(headers, want) {
t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want) t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
@ -3312,7 +3352,7 @@ func BenchmarkServer_GetRequest(b *testing.B) {
defer st.Close() defer st.Close()
st.greet() st.greet()
// Give the server quota to reply. (plus it has the the 64KB) // Give the server quota to reply. (plus it has the 64KB)
if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil { if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -3343,7 +3383,7 @@ func BenchmarkServer_PostRequest(b *testing.B) {
}) })
defer st.Close() defer st.Close()
st.greet() st.greet()
// Give the server quota to reply. (plus it has the the 64KB) // Give the server quota to reply. (plus it has the 64KB)
if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil { if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -3685,3 +3725,37 @@ func TestRequestBodyReadCloseRace(t *testing.T) {
<-done <-done
} }
} }
func TestIssue20704Race(t *testing.T) {
if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" {
t.Skip("skipping in short mode")
}
const (
itemSize = 1 << 10
itemCount = 100
)
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
for i := 0; i < itemCount; i++ {
_, err := w.Write(make([]byte, itemSize))
if err != nil {
return
}
}
}, optOnlyServer)
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cl := &http.Client{Transport: tr}
for i := 0; i < 1000; i++ {
resp, err := cl.Get(st.ts.URL)
if err != nil {
t.Fatal(err)
}
// Force a RST stream to the server by closing without
// reading the body:
resp.Body.Close()
}
}

View File

@ -18,6 +18,7 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"math" "math"
mathrand "math/rand"
"net" "net"
"net/http" "net/http"
"sort" "sort"
@ -86,7 +87,7 @@ type Transport struct {
// MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to
// send in the initial settings frame. It is how many bytes // send in the initial settings frame. It is how many bytes
// of response headers are allow. Unlike the http2 spec, zero here // of response headers are allowed. Unlike the http2 spec, zero here
// means to use a default limit (currently 10MB). If you actually // means to use a default limit (currently 10MB). If you actually
// want to advertise an ulimited value to the peer, Transport // want to advertise an ulimited value to the peer, Transport
// interprets the highest possible value here (0xffffffff or 1<<32-1) // interprets the highest possible value here (0xffffffff or 1<<32-1)
@ -164,15 +165,17 @@ type ClientConn struct {
goAwayDebug string // goAway frame's debug data, retained as a string goAwayDebug string // goAway frame's debug data, retained as a string
streams map[uint32]*clientStream // client-initiated streams map[uint32]*clientStream // client-initiated
nextStreamID uint32 nextStreamID uint32
pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
pings map[[8]byte]chan struct{} // in flight ping data to notification channel pings map[[8]byte]chan struct{} // in flight ping data to notification channel
bw *bufio.Writer bw *bufio.Writer
br *bufio.Reader br *bufio.Reader
fr *Framer fr *Framer
lastActive time.Time lastActive time.Time
// Settings from peer: (also guarded by mu) // Settings from peer: (also guarded by mu)
maxFrameSize uint32 maxFrameSize uint32
maxConcurrentStreams uint32 maxConcurrentStreams uint32
initialWindowSize uint32 peerMaxHeaderListSize uint64
initialWindowSize uint32
hbuf bytes.Buffer // HPACK encoder writes into this hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder henc *hpack.Encoder
@ -216,35 +219,45 @@ type clientStream struct {
resTrailer *http.Header // client's Response.Trailer resTrailer *http.Header // client's Response.Trailer
} }
// awaitRequestCancel runs in its own goroutine and waits for the user // awaitRequestCancel waits for the user to cancel a request or for the done
// to cancel a RoundTrip request, its context to expire, or for the // channel to be signaled. A non-nil error is returned only if the request was
// request to be done (any way it might be removed from the cc.streams // canceled.
// map: peer reset, successful completion, TCP connection breakage, func awaitRequestCancel(req *http.Request, done <-chan struct{}) error {
// etc)
func (cs *clientStream) awaitRequestCancel(req *http.Request) {
ctx := reqContext(req) ctx := reqContext(req)
if req.Cancel == nil && ctx.Done() == nil { if req.Cancel == nil && ctx.Done() == nil {
return return nil
} }
select { select {
case <-req.Cancel: case <-req.Cancel:
cs.cancelStream() return errRequestCanceled
cs.bufPipe.CloseWithError(errRequestCanceled)
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}
// awaitRequestCancel waits for the user to cancel a request, its context to
// expire, or for the request to be done (any way it might be removed from the
// cc.streams map: peer reset, successful completion, TCP connection breakage,
// etc). If the request is canceled, then cs will be canceled and closed.
func (cs *clientStream) awaitRequestCancel(req *http.Request) {
if err := awaitRequestCancel(req, cs.done); err != nil {
cs.cancelStream() cs.cancelStream()
cs.bufPipe.CloseWithError(ctx.Err()) cs.bufPipe.CloseWithError(err)
case <-cs.done:
} }
} }
func (cs *clientStream) cancelStream() { func (cs *clientStream) cancelStream() {
cs.cc.mu.Lock() cc := cs.cc
cc.mu.Lock()
didReset := cs.didReset didReset := cs.didReset
cs.didReset = true cs.didReset = true
cs.cc.mu.Unlock() cc.mu.Unlock()
if !didReset { if !didReset {
cs.cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
cc.forgetStreamID(cs.ID)
} }
} }
@ -261,6 +274,13 @@ func (cs *clientStream) checkResetOrDone() error {
} }
} }
func (cs *clientStream) getStartedWrite() bool {
cc := cs.cc
cc.mu.Lock()
defer cc.mu.Unlock()
return cs.startedWrite
}
func (cs *clientStream) abortRequestBodyWrite(err error) { func (cs *clientStream) abortRequestBodyWrite(err error) {
if err == nil { if err == nil {
panic("nil error") panic("nil error")
@ -286,7 +306,26 @@ func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
return return
} }
var ErrNoCachedConn = errors.New("http2: no cached connection was available") // noCachedConnError is the concrete type of ErrNoCachedConn, which
// needs to be detected by net/http regardless of whether it's its
// bundled version (in h2_bundle.go with a rewritten type name) or
// from a user's x/net/http2. As such, as it has a unique method name
// (IsHTTP2NoCachedConnError) that net/http sniffs for via func
// isNoCachedConnError.
type noCachedConnError struct{}
func (noCachedConnError) IsHTTP2NoCachedConnError() {}
func (noCachedConnError) Error() string { return "http2: no cached connection was available" }
// isNoCachedConnError reports whether err is of type noCachedConnError
// or its equivalent renamed type in net/http2's h2_bundle.go. Both types
// may coexist in the same running program.
func isNoCachedConnError(err error) bool {
_, ok := err.(interface{ IsHTTP2NoCachedConnError() })
return ok
}
var ErrNoCachedConn error = noCachedConnError{}
// RoundTripOpt are options for the Transport.RoundTripOpt method. // RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct { type RoundTripOpt struct {
@ -329,17 +368,28 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
} }
addr := authorityAddr(req.URL.Scheme, req.URL.Host) addr := authorityAddr(req.URL.Scheme, req.URL.Host)
for { for retry := 0; ; retry++ {
cc, err := t.connPool().GetClientConn(req, addr) cc, err := t.connPool().GetClientConn(req, addr)
if err != nil { if err != nil {
t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err)
return nil, err return nil, err
} }
traceGotConn(req, cc) traceGotConn(req, cc)
res, err := cc.RoundTrip(req) res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req)
if err != nil { if err != nil && retry <= 6 {
if req, err = shouldRetryRequest(req, err); err == nil { if req, err = shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil {
continue // After the first retry, do exponential backoff with 10% jitter.
if retry == 0 {
continue
}
backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64())
select {
case <-time.After(time.Second * time.Duration(backoff)):
continue
case <-reqContext(req).Done():
return nil, reqContext(req).Err()
}
} }
} }
if err != nil { if err != nil {
@ -360,43 +410,50 @@ func (t *Transport) CloseIdleConnections() {
} }
var ( var (
errClientConnClosed = errors.New("http2: client conn is closed") errClientConnClosed = errors.New("http2: client conn is closed")
errClientConnUnusable = errors.New("http2: client conn not usable") errClientConnUnusable = errors.New("http2: client conn not usable")
errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written")
) )
// shouldRetryRequest is called by RoundTrip when a request fails to get // shouldRetryRequest is called by RoundTrip when a request fails to get
// response headers. It is always called with a non-nil error. // response headers. It is always called with a non-nil error.
// It returns either a request to retry (either the same request, or a // It returns either a request to retry (either the same request, or a
// modified clone), or an error if the request can't be replayed. // modified clone), or an error if the request can't be replayed.
func shouldRetryRequest(req *http.Request, err error) (*http.Request, error) { func shouldRetryRequest(req *http.Request, err error, afterBodyWrite bool) (*http.Request, error) {
switch err { if !canRetryError(err) {
default:
return nil, err return nil, err
case errClientConnUnusable, errClientConnGotGoAway:
return req, nil
case errClientConnGotGoAwayAfterSomeReqBody:
// If the Body is nil (or http.NoBody), it's safe to reuse
// this request and its Body.
if req.Body == nil || reqBodyIsNoBody(req.Body) {
return req, nil
}
// Otherwise we depend on the Request having its GetBody
// func defined.
getBody := reqGetBody(req) // Go 1.8: getBody = req.GetBody
if getBody == nil {
return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error")
}
body, err := getBody()
if err != nil {
return nil, err
}
newReq := *req
newReq.Body = body
return &newReq, nil
} }
if !afterBodyWrite {
return req, nil
}
// If the Body is nil (or http.NoBody), it's safe to reuse
// this request and its Body.
if req.Body == nil || reqBodyIsNoBody(req.Body) {
return req, nil
}
// Otherwise we depend on the Request having its GetBody
// func defined.
getBody := reqGetBody(req) // Go 1.8: getBody = req.GetBody
if getBody == nil {
return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err)
}
body, err := getBody()
if err != nil {
return nil, err
}
newReq := *req
newReq.Body = body
return &newReq, nil
}
func canRetryError(err error) bool {
if err == errClientConnUnusable || err == errClientConnGotGoAway {
return true
}
if se, ok := err.(StreamError); ok {
return se.Code == ErrCodeRefusedStream
}
return false
} }
func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) { func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
@ -474,17 +531,18 @@ func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
cc := &ClientConn{ cc := &ClientConn{
t: t, t: t,
tconn: c, tconn: c,
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
nextStreamID: 1, nextStreamID: 1,
maxFrameSize: 16 << 10, // spec default maxFrameSize: 16 << 10, // spec default
initialWindowSize: 65535, // spec default initialWindowSize: 65535, // spec default
maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough. maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
streams: make(map[uint32]*clientStream), peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
singleUse: singleUse, streams: make(map[uint32]*clientStream),
wantSettingsAck: true, singleUse: singleUse,
pings: make(map[[8]byte]chan struct{}), wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}),
} }
if d := t.idleConnTimeout(); d != 0 { if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d cc.idleTimeout = d
@ -560,6 +618,8 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
} }
} }
// CanTakeNewRequest reports whether the connection can take a new request,
// meaning it has not been closed or received or sent a GOAWAY.
func (cc *ClientConn) CanTakeNewRequest() bool { func (cc *ClientConn) CanTakeNewRequest() bool {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
@ -571,8 +631,7 @@ func (cc *ClientConn) canTakeNewRequestLocked() bool {
return false return false
} }
return cc.goAway == nil && !cc.closed && return cc.goAway == nil && !cc.closed &&
int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) && int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
cc.nextStreamID < math.MaxInt32
} }
// onIdleTimeout is called from a time.AfterFunc goroutine. It will // onIdleTimeout is called from a time.AfterFunc goroutine. It will
@ -694,7 +753,7 @@ func checkConnHeaders(req *http.Request) error {
// req.ContentLength, where 0 actually means zero (not unknown) and -1 // req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown. // means unknown.
func actualContentLength(req *http.Request) int64 { func actualContentLength(req *http.Request) int64 {
if req.Body == nil { if req.Body == nil || reqBodyIsNoBody(req.Body) {
return 0 return 0
} }
if req.ContentLength != 0 { if req.ContentLength != 0 {
@ -704,8 +763,13 @@ func actualContentLength(req *http.Request) int64 {
} }
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
resp, _, err := cc.roundTrip(req)
return resp, err
}
func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) {
if err := checkConnHeaders(req); err != nil { if err := checkConnHeaders(req); err != nil {
return nil, err return nil, false, err
} }
if cc.idleTimer != nil { if cc.idleTimer != nil {
cc.idleTimer.Stop() cc.idleTimer.Stop()
@ -713,20 +777,19 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
trailers, err := commaSeparatedTrailers(req) trailers, err := commaSeparatedTrailers(req)
if err != nil { if err != nil {
return nil, err return nil, false, err
} }
hasTrailers := trailers != "" hasTrailers := trailers != ""
cc.mu.Lock() cc.mu.Lock()
cc.lastActive = time.Now() if err := cc.awaitOpenSlotForRequest(req); err != nil {
if cc.closed || !cc.canTakeNewRequestLocked() {
cc.mu.Unlock() cc.mu.Unlock()
return nil, errClientConnUnusable return nil, false, err
} }
body := req.Body body := req.Body
hasBody := body != nil
contentLen := actualContentLength(req) contentLen := actualContentLength(req)
hasBody := contentLen != 0
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
var requestedGzip bool var requestedGzip bool
@ -755,7 +818,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen) hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen)
if err != nil { if err != nil {
cc.mu.Unlock() cc.mu.Unlock()
return nil, err return nil, false, err
} }
cs := cc.newStream() cs := cc.newStream()
@ -767,7 +830,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cc.wmu.Lock() cc.wmu.Lock()
endStream := !hasBody && !hasTrailers endStream := !hasBody && !hasTrailers
werr := cc.writeHeaders(cs.ID, endStream, hdrs) werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
cc.wmu.Unlock() cc.wmu.Unlock()
traceWroteHeaders(cs.trace) traceWroteHeaders(cs.trace)
cc.mu.Unlock() cc.mu.Unlock()
@ -781,7 +844,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
// Don't bother sending a RST_STREAM (our write already failed; // Don't bother sending a RST_STREAM (our write already failed;
// no need to keep writing) // no need to keep writing)
traceWroteRequest(cs.trace, werr) traceWroteRequest(cs.trace, werr)
return nil, werr return nil, false, werr
} }
var respHeaderTimer <-chan time.Time var respHeaderTimer <-chan time.Time
@ -800,7 +863,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
bodyWritten := false bodyWritten := false
ctx := reqContext(req) ctx := reqContext(req)
handleReadLoopResponse := func(re resAndError) (*http.Response, error) { handleReadLoopResponse := func(re resAndError) (*http.Response, bool, error) {
res := re.res res := re.res
if re.err != nil || res.StatusCode > 299 { if re.err != nil || res.StatusCode > 299 {
// On error or status code 3xx, 4xx, 5xx, etc abort any // On error or status code 3xx, 4xx, 5xx, etc abort any
@ -816,19 +879,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cs.abortRequestBodyWrite(errStopReqBodyWrite) cs.abortRequestBodyWrite(errStopReqBodyWrite)
} }
if re.err != nil { if re.err != nil {
if re.err == errClientConnGotGoAway {
cc.mu.Lock()
if cs.startedWrite {
re.err = errClientConnGotGoAwayAfterSomeReqBody
}
cc.mu.Unlock()
}
cc.forgetStreamID(cs.ID) cc.forgetStreamID(cs.ID)
return nil, re.err return nil, cs.getStartedWrite(), re.err
} }
res.Request = req res.Request = req
res.TLS = cc.tlsState res.TLS = cc.tlsState
return res, nil return res, false, nil
} }
for { for {
@ -836,37 +892,37 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
case re := <-readLoopResCh: case re := <-readLoopResCh:
return handleReadLoopResponse(re) return handleReadLoopResponse(re)
case <-respHeaderTimer: case <-respHeaderTimer:
cc.forgetStreamID(cs.ID)
if !hasBody || bodyWritten { if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else { } else {
bodyWriter.cancel() bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
} }
return nil, errTimeout cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), errTimeout
case <-ctx.Done(): case <-ctx.Done():
cc.forgetStreamID(cs.ID)
if !hasBody || bodyWritten { if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else { } else {
bodyWriter.cancel() bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
} }
return nil, ctx.Err() cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), ctx.Err()
case <-req.Cancel: case <-req.Cancel:
cc.forgetStreamID(cs.ID)
if !hasBody || bodyWritten { if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else { } else {
bodyWriter.cancel() bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
} }
return nil, errRequestCanceled cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), errRequestCanceled
case <-cs.peerReset: case <-cs.peerReset:
// processResetStream already removed the // processResetStream already removed the
// stream from the streams map; no need for // stream from the streams map; no need for
// forgetStreamID. // forgetStreamID.
return nil, cs.resetErr return nil, cs.getStartedWrite(), cs.resetErr
case err := <-bodyWriter.resc: case err := <-bodyWriter.resc:
// Prefer the read loop's response, if available. Issue 16102. // Prefer the read loop's response, if available. Issue 16102.
select { select {
@ -875,7 +931,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
default: default:
} }
if err != nil { if err != nil {
return nil, err return nil, cs.getStartedWrite(), err
} }
bodyWritten = true bodyWritten = true
if d := cc.responseHeaderTimeout(); d != 0 { if d := cc.responseHeaderTimeout(); d != 0 {
@ -887,14 +943,55 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
} }
} }
// awaitOpenSlotForRequest waits until len(streams) < maxConcurrentStreams.
// Must hold cc.mu.
func (cc *ClientConn) awaitOpenSlotForRequest(req *http.Request) error {
var waitingForConn chan struct{}
var waitingForConnErr error // guarded by cc.mu
for {
cc.lastActive = time.Now()
if cc.closed || !cc.canTakeNewRequestLocked() {
if waitingForConn != nil {
close(waitingForConn)
}
return errClientConnUnusable
}
if int64(len(cc.streams))+1 <= int64(cc.maxConcurrentStreams) {
if waitingForConn != nil {
close(waitingForConn)
}
return nil
}
// Unfortunately, we cannot wait on a condition variable and channel at
// the same time, so instead, we spin up a goroutine to check if the
// request is canceled while we wait for a slot to open in the connection.
if waitingForConn == nil {
waitingForConn = make(chan struct{})
go func() {
if err := awaitRequestCancel(req, waitingForConn); err != nil {
cc.mu.Lock()
waitingForConnErr = err
cc.cond.Broadcast()
cc.mu.Unlock()
}
}()
}
cc.pendingRequests++
cc.cond.Wait()
cc.pendingRequests--
if waitingForConnErr != nil {
return waitingForConnErr
}
}
}
// requires cc.wmu be held // requires cc.wmu be held
func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error { func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error {
first := true // first frame written (HEADERS is first, then CONTINUATION) first := true // first frame written (HEADERS is first, then CONTINUATION)
frameSize := int(cc.maxFrameSize)
for len(hdrs) > 0 && cc.werr == nil { for len(hdrs) > 0 && cc.werr == nil {
chunk := hdrs chunk := hdrs
if len(chunk) > frameSize { if len(chunk) > maxFrameSize {
chunk = chunk[:frameSize] chunk = chunk[:maxFrameSize]
} }
hdrs = hdrs[len(chunk):] hdrs = hdrs[len(chunk):]
endHeaders := len(hdrs) == 0 endHeaders := len(hdrs) == 0
@ -1002,17 +1099,26 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
var trls []byte var trls []byte
if hasTrailers { if hasTrailers {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() trls, err = cc.encodeTrailers(req)
trls = cc.encodeTrailers(req) cc.mu.Unlock()
if err != nil {
cc.writeStreamReset(cs.ID, ErrCodeInternal, err)
cc.forgetStreamID(cs.ID)
return err
}
} }
cc.mu.Lock()
maxFrameSize := int(cc.maxFrameSize)
cc.mu.Unlock()
cc.wmu.Lock() cc.wmu.Lock()
defer cc.wmu.Unlock() defer cc.wmu.Unlock()
// Two ways to send END_STREAM: either with trailers, or // Two ways to send END_STREAM: either with trailers, or
// with an empty DATA frame. // with an empty DATA frame.
if len(trls) > 0 { if len(trls) > 0 {
err = cc.writeHeaders(cs.ID, true, trls) err = cc.writeHeaders(cs.ID, true, maxFrameSize, trls)
} else { } else {
err = cc.fr.WriteData(cs.ID, true, nil) err = cc.fr.WriteData(cs.ID, true, nil)
} }
@ -1106,62 +1212,86 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
} }
} }
// 8.1.2.3 Request Pseudo-Header Fields enumerateHeaders := func(f func(name, value string)) {
// The :path pseudo-header field includes the path and query parts of the // 8.1.2.3 Request Pseudo-Header Fields
// target URI (the path-absolute production and optionally a '?' character // The :path pseudo-header field includes the path and query parts of the
// followed by the query production (see Sections 3.3 and 3.4 of // target URI (the path-absolute production and optionally a '?' character
// [RFC3986]). // followed by the query production (see Sections 3.3 and 3.4 of
cc.writeHeader(":authority", host) // [RFC3986]).
cc.writeHeader(":method", req.Method) f(":authority", host)
if req.Method != "CONNECT" { f(":method", req.Method)
cc.writeHeader(":path", path) if req.Method != "CONNECT" {
cc.writeHeader(":scheme", req.URL.Scheme) f(":path", path)
} f(":scheme", req.URL.Scheme)
if trailers != "" { }
cc.writeHeader("trailer", trailers) if trailers != "" {
f("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
} else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
strings.EqualFold(k, "keep-alive") {
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
} else if strings.EqualFold(k, "user-agent") {
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
}
for _, v := range vv {
f(k, v)
}
}
if shouldSendReqContentLength(req.Method, contentLength) {
f("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
f("accept-encoding", "gzip")
}
if !didUA {
f("user-agent", defaultUserAgent)
}
} }
var didUA bool // Do a first pass over the headers counting bytes to ensure
for k, vv := range req.Header { // we don't exceed cc.peerMaxHeaderListSize. This is done as a
lowKey := strings.ToLower(k) // separate pass before encoding the headers to prevent
switch lowKey { // modifying the hpack state.
case "host", "content-length": hlSize := uint64(0)
// Host is :authority, already sent. enumerateHeaders(func(name, value string) {
// Content-Length is automatic, set below. hf := hpack.HeaderField{Name: name, Value: value}
continue hlSize += uint64(hf.Size())
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": })
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific if hlSize > cc.peerMaxHeaderListSize {
// fields. We have already checked if any return nil, errRequestHeaderListSize
// are error-worthy so just ignore the rest.
continue
case "user-agent":
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
}
for _, v := range vv {
cc.writeHeader(lowKey, v)
}
}
if shouldSendReqContentLength(req.Method, contentLength) {
cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
cc.writeHeader("accept-encoding", "gzip")
}
if !didUA {
cc.writeHeader("user-agent", defaultUserAgent)
} }
// Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) {
cc.writeHeader(strings.ToLower(name), value)
})
return cc.hbuf.Bytes(), nil return cc.hbuf.Bytes(), nil
} }
@ -1188,17 +1318,29 @@ func shouldSendReqContentLength(method string, contentLength int64) bool {
} }
// requires cc.mu be held. // requires cc.mu be held.
func (cc *ClientConn) encodeTrailers(req *http.Request) []byte { func (cc *ClientConn) encodeTrailers(req *http.Request) ([]byte, error) {
cc.hbuf.Reset() cc.hbuf.Reset()
hlSize := uint64(0)
for k, vv := range req.Trailer { for k, vv := range req.Trailer {
// Transfer-Encoding, etc.. have already been filter at the for _, v := range vv {
hf := hpack.HeaderField{Name: k, Value: v}
hlSize += uint64(hf.Size())
}
}
if hlSize > cc.peerMaxHeaderListSize {
return nil, errRequestHeaderListSize
}
for k, vv := range req.Trailer {
// Transfer-Encoding, etc.. have already been filtered at the
// start of RoundTrip // start of RoundTrip
lowKey := strings.ToLower(k) lowKey := strings.ToLower(k)
for _, v := range vv { for _, v := range vv {
cc.writeHeader(lowKey, v) cc.writeHeader(lowKey, v)
} }
} }
return cc.hbuf.Bytes() return cc.hbuf.Bytes(), nil
} }
func (cc *ClientConn) writeHeader(name, value string) { func (cc *ClientConn) writeHeader(name, value string) {
@ -1246,7 +1388,9 @@ func (cc *ClientConn) streamByID(id uint32, andRemove bool) *clientStream {
cc.idleTimer.Reset(cc.idleTimeout) cc.idleTimer.Reset(cc.idleTimeout)
} }
close(cs.done) close(cs.done)
cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl // Wake up checkResetOrDone via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
cc.cond.Broadcast()
} }
return cs return cs
} }
@ -1254,17 +1398,12 @@ func (cc *ClientConn) streamByID(id uint32, andRemove bool) *clientStream {
// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
type clientConnReadLoop struct { type clientConnReadLoop struct {
cc *ClientConn cc *ClientConn
activeRes map[uint32]*clientStream // keyed by streamID
closeWhenIdle bool closeWhenIdle bool
} }
// readLoop runs in its own goroutine and reads and dispatches frames. // readLoop runs in its own goroutine and reads and dispatches frames.
func (cc *ClientConn) readLoop() { func (cc *ClientConn) readLoop() {
rl := &clientConnReadLoop{ rl := &clientConnReadLoop{cc: cc}
cc: cc,
activeRes: make(map[uint32]*clientStream),
}
defer rl.cleanup() defer rl.cleanup()
cc.readerErr = rl.run() cc.readerErr = rl.run()
if ce, ok := cc.readerErr.(ConnectionError); ok { if ce, ok := cc.readerErr.(ConnectionError); ok {
@ -1319,10 +1458,8 @@ func (rl *clientConnReadLoop) cleanup() {
} else if err == io.EOF { } else if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
for _, cs := range rl.activeRes {
cs.bufPipe.CloseWithError(err)
}
for _, cs := range cc.streams { for _, cs := range cc.streams {
cs.bufPipe.CloseWithError(err) // no-op if already closed
select { select {
case cs.resc <- resAndError{err: err}: case cs.resc <- resAndError{err: err}:
default: default:
@ -1345,8 +1482,9 @@ func (rl *clientConnReadLoop) run() error {
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
} }
if se, ok := err.(StreamError); ok { if se, ok := err.(StreamError); ok {
if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil { if cs := cc.streamByID(se.StreamID, false); cs != nil {
cs.cc.writeStreamReset(cs.ID, se.Code, err) cs.cc.writeStreamReset(cs.ID, se.Code, err)
cs.cc.forgetStreamID(cs.ID)
if se.Cause == nil { if se.Cause == nil {
se.Cause = cc.fr.errDetail se.Cause = cc.fr.errDetail
} }
@ -1399,7 +1537,7 @@ func (rl *clientConnReadLoop) run() error {
} }
return err return err
} }
if rl.closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 { if rl.closeWhenIdle && gotReply && maybeIdle {
cc.closeIfIdle() cc.closeIfIdle()
} }
} }
@ -1407,13 +1545,31 @@ func (rl *clientConnReadLoop) run() error {
func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error { func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
cc := rl.cc cc := rl.cc
cs := cc.streamByID(f.StreamID, f.StreamEnded()) cs := cc.streamByID(f.StreamID, false)
if cs == nil { if cs == nil {
// We'd get here if we canceled a request while the // We'd get here if we canceled a request while the
// server had its response still in flight. So if this // server had its response still in flight. So if this
// was just something we canceled, ignore it. // was just something we canceled, ignore it.
return nil return nil
} }
if f.StreamEnded() {
// Issue 20521: If the stream has ended, streamByID() causes
// clientStream.done to be closed, which causes the request's bodyWriter
// to be closed with an errStreamClosed, which may be received by
// clientConn.RoundTrip before the result of processing these headers.
// Deferring stream closure allows the header processing to occur first.
// clientConn.RoundTrip may still receive the bodyWriter error first, but
// the fix for issue 16102 prioritises any response.
//
// Issue 22413: If there is no request body, we should close the
// stream before writing to cs.resc so that the stream is closed
// immediately once RoundTrip returns.
if cs.req.Body != nil {
defer cc.forgetStreamID(f.StreamID)
} else {
cc.forgetStreamID(f.StreamID)
}
}
if !cs.firstByte { if !cs.firstByte {
if cs.trace != nil { if cs.trace != nil {
// TODO(bradfitz): move first response byte earlier, // TODO(bradfitz): move first response byte earlier,
@ -1437,6 +1593,7 @@ func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
} }
// Any other error type is a stream error. // Any other error type is a stream error.
cs.cc.writeStreamReset(f.StreamID, ErrCodeProtocol, err) cs.cc.writeStreamReset(f.StreamID, ErrCodeProtocol, err)
cc.forgetStreamID(cs.ID)
cs.resc <- resAndError{err: err} cs.resc <- resAndError{err: err}
return nil // return nil from process* funcs to keep conn alive return nil // return nil from process* funcs to keep conn alive
} }
@ -1444,9 +1601,6 @@ func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
// (nil, nil) special case. See handleResponse docs. // (nil, nil) special case. See handleResponse docs.
return nil return nil
} }
if res.Body != noBody {
rl.activeRes[cs.ID] = cs
}
cs.resTrailer = &res.Trailer cs.resTrailer = &res.Trailer
cs.resc <- resAndError{res: res} cs.resc <- resAndError{res: res}
return nil return nil
@ -1466,11 +1620,11 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
status := f.PseudoValue("status") status := f.PseudoValue("status")
if status == "" { if status == "" {
return nil, errors.New("missing status pseudo header") return nil, errors.New("malformed response from server: missing status pseudo header")
} }
statusCode, err := strconv.Atoi(status) statusCode, err := strconv.Atoi(status)
if err != nil { if err != nil {
return nil, errors.New("malformed non-numeric status pseudo header") return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header")
} }
if statusCode == 100 { if statusCode == 100 {
@ -1668,6 +1822,7 @@ func (b transportResponseBody) Close() error {
} }
cs.bufPipe.BreakWithError(errClosedResponseBody) cs.bufPipe.BreakWithError(errClosedResponseBody)
cc.forgetStreamID(cs.ID)
return nil return nil
} }
@ -1702,7 +1857,23 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
} }
return nil return nil
} }
if !cs.firstByte {
cc.logf("protocol error: received DATA before a HEADERS frame")
rl.endStreamError(cs, StreamError{
StreamID: f.StreamID,
Code: ErrCodeProtocol,
})
return nil
}
if f.Length > 0 { if f.Length > 0 {
if cs.req.Method == "HEAD" && len(data) > 0 {
cc.logf("protocol error: received DATA on a HEAD request")
rl.endStreamError(cs, StreamError{
StreamID: f.StreamID,
Code: ErrCodeProtocol,
})
return nil
}
// Check connection-level flow control. // Check connection-level flow control.
cc.mu.Lock() cc.mu.Lock()
if cs.inflow.available() >= int32(f.Length) { if cs.inflow.available() >= int32(f.Length) {
@ -1713,16 +1884,27 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
} }
// Return any padded flow control now, since we won't // Return any padded flow control now, since we won't
// refund it later on body reads. // refund it later on body reads.
if pad := int32(f.Length) - int32(len(data)); pad > 0 { var refund int
cs.inflow.add(pad) if pad := int(f.Length) - len(data); pad > 0 {
cc.inflow.add(pad) refund += pad
}
// Return len(data) now if the stream is already closed,
// since data will never be read.
didReset := cs.didReset
if didReset {
refund += len(data)
}
if refund > 0 {
cc.inflow.add(int32(refund))
cc.wmu.Lock() cc.wmu.Lock()
cc.fr.WriteWindowUpdate(0, uint32(pad)) cc.fr.WriteWindowUpdate(0, uint32(refund))
cc.fr.WriteWindowUpdate(cs.ID, uint32(pad)) if !didReset {
cs.inflow.add(int32(refund))
cc.fr.WriteWindowUpdate(cs.ID, uint32(refund))
}
cc.bw.Flush() cc.bw.Flush()
cc.wmu.Unlock() cc.wmu.Unlock()
} }
didReset := cs.didReset
cc.mu.Unlock() cc.mu.Unlock()
if len(data) > 0 && !didReset { if len(data) > 0 && !didReset {
@ -1753,11 +1935,10 @@ func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) {
err = io.EOF err = io.EOF
code = cs.copyTrailers code = cs.copyTrailers
} }
cs.bufPipe.closeWithErrorAndCode(err, code)
delete(rl.activeRes, cs.ID)
if isConnectionCloseRequest(cs.req) { if isConnectionCloseRequest(cs.req) {
rl.closeWhenIdle = true rl.closeWhenIdle = true
} }
cs.bufPipe.closeWithErrorAndCode(err, code)
select { select {
case cs.resc <- resAndError{err: err}: case cs.resc <- resAndError{err: err}:
@ -1805,6 +1986,8 @@ func (rl *clientConnReadLoop) processSettings(f *SettingsFrame) error {
cc.maxFrameSize = s.Val cc.maxFrameSize = s.Val
case SettingMaxConcurrentStreams: case SettingMaxConcurrentStreams:
cc.maxConcurrentStreams = s.Val cc.maxConcurrentStreams = s.Val
case SettingMaxHeaderListSize:
cc.peerMaxHeaderListSize = uint64(s.Val)
case SettingInitialWindowSize: case SettingInitialWindowSize:
// Values above the maximum flow-control // Values above the maximum flow-control
// window size of 2^31-1 MUST be treated as a // window size of 2^31-1 MUST be treated as a
@ -1882,7 +2065,6 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
cs.bufPipe.CloseWithError(err) cs.bufPipe.CloseWithError(err)
cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl
} }
delete(rl.activeRes, cs.ID)
return nil return nil
} }
@ -1971,6 +2153,7 @@ func (cc *ClientConn) writeStreamReset(streamID uint32, code ErrCode, err error)
var ( var (
errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit")
errPseudoTrailers = errors.New("http2: invalid pseudo header in trailers") errPseudoTrailers = errors.New("http2: invalid pseudo header in trailers")
) )

File diff suppressed because it is too large Load Diff

View File

@ -10,7 +10,6 @@ import (
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"time"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"golang.org/x/net/lex/httplex" "golang.org/x/net/lex/httplex"
@ -90,11 +89,7 @@ type writeGoAway struct {
func (p *writeGoAway) writeFrame(ctx writeContext) error { func (p *writeGoAway) writeFrame(ctx writeContext) error {
err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil)
if p.code != 0 { ctx.Flush() // ignore error: we're hanging up on them anyway
ctx.Flush() // ignore error: we're hanging up on them anyway
time.Sleep(50 * time.Millisecond)
ctx.CloseConn()
}
return err return err
} }

274
vendor/golang.org/x/net/icmp/diag_test.go generated vendored Normal file
View File

@ -0,0 +1,274 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package icmp_test
import (
"errors"
"fmt"
"net"
"os"
"runtime"
"sync"
"testing"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/internal/iana"
"golang.org/x/net/internal/nettest"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
type diagTest struct {
network, address string
protocol int
m icmp.Message
}
func TestDiag(t *testing.T) {
if testing.Short() {
t.Skip("avoid external network")
}
t.Run("Ping/NonPrivileged", func(t *testing.T) {
switch runtime.GOOS {
case "darwin":
case "linux":
t.Log("you may need to adjust the net.ipv4.ping_group_range kernel state")
default:
t.Logf("not supported on %s", runtime.GOOS)
return
}
for i, dt := range []diagTest{
{
"udp4", "0.0.0.0", iana.ProtocolICMP,
icmp.Message{
Type: ipv4.ICMPTypeEcho, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff,
Data: []byte("HELLO-R-U-THERE"),
},
},
},
{
"udp6", "::", iana.ProtocolIPv6ICMP,
icmp.Message{
Type: ipv6.ICMPTypeEchoRequest, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff,
Data: []byte("HELLO-R-U-THERE"),
},
},
},
} {
if err := doDiag(dt, i); err != nil {
t.Error(err)
}
}
})
t.Run("Ping/Privileged", func(t *testing.T) {
if m, ok := nettest.SupportsRawIPSocket(); !ok {
t.Skip(m)
}
for i, dt := range []diagTest{
{
"ip4:icmp", "0.0.0.0", iana.ProtocolICMP,
icmp.Message{
Type: ipv4.ICMPTypeEcho, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff,
Data: []byte("HELLO-R-U-THERE"),
},
},
},
{
"ip6:ipv6-icmp", "::", iana.ProtocolIPv6ICMP,
icmp.Message{
Type: ipv6.ICMPTypeEchoRequest, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff,
Data: []byte("HELLO-R-U-THERE"),
},
},
},
} {
if err := doDiag(dt, i); err != nil {
t.Error(err)
}
}
})
t.Run("Probe/Privileged", func(t *testing.T) {
if m, ok := nettest.SupportsRawIPSocket(); !ok {
t.Skip(m)
}
for i, dt := range []diagTest{
{
"ip4:icmp", "0.0.0.0", iana.ProtocolICMP,
icmp.Message{
Type: ipv4.ICMPTypeExtendedEchoRequest, Code: 0,
Body: &icmp.ExtendedEchoRequest{
ID: os.Getpid() & 0xffff,
Local: true,
Extensions: []icmp.Extension{
&icmp.InterfaceIdent{
Class: 3, Type: 1,
Name: "doesnotexist",
},
},
},
},
},
{
"ip6:ipv6-icmp", "::", iana.ProtocolIPv6ICMP,
icmp.Message{
Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 0,
Body: &icmp.ExtendedEchoRequest{
ID: os.Getpid() & 0xffff,
Local: true,
Extensions: []icmp.Extension{
&icmp.InterfaceIdent{
Class: 3, Type: 1,
Name: "doesnotexist",
},
},
},
},
},
} {
if err := doDiag(dt, i); err != nil {
t.Error(err)
}
}
})
}
func doDiag(dt diagTest, seq int) error {
c, err := icmp.ListenPacket(dt.network, dt.address)
if err != nil {
return err
}
defer c.Close()
dst, err := googleAddr(c, dt.protocol)
if err != nil {
return err
}
if dt.network != "udp6" && dt.protocol == iana.ProtocolIPv6ICMP {
var f ipv6.ICMPFilter
f.SetAll(true)
f.Accept(ipv6.ICMPTypeDestinationUnreachable)
f.Accept(ipv6.ICMPTypePacketTooBig)
f.Accept(ipv6.ICMPTypeTimeExceeded)
f.Accept(ipv6.ICMPTypeParameterProblem)
f.Accept(ipv6.ICMPTypeEchoReply)
f.Accept(ipv6.ICMPTypeExtendedEchoReply)
if err := c.IPv6PacketConn().SetICMPFilter(&f); err != nil {
return err
}
}
switch m := dt.m.Body.(type) {
case *icmp.Echo:
m.Seq = 1 << uint(seq)
case *icmp.ExtendedEchoRequest:
m.Seq = 1 << uint(seq)
}
wb, err := dt.m.Marshal(nil)
if err != nil {
return err
}
if n, err := c.WriteTo(wb, dst); err != nil {
return err
} else if n != len(wb) {
return fmt.Errorf("got %v; want %v", n, len(wb))
}
rb := make([]byte, 1500)
if err := c.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
return err
}
n, peer, err := c.ReadFrom(rb)
if err != nil {
return err
}
rm, err := icmp.ParseMessage(dt.protocol, rb[:n])
if err != nil {
return err
}
switch {
case dt.m.Type == ipv4.ICMPTypeEcho && rm.Type == ipv4.ICMPTypeEchoReply:
fallthrough
case dt.m.Type == ipv6.ICMPTypeEchoRequest && rm.Type == ipv6.ICMPTypeEchoReply:
fallthrough
case dt.m.Type == ipv4.ICMPTypeExtendedEchoRequest && rm.Type == ipv4.ICMPTypeExtendedEchoReply:
fallthrough
case dt.m.Type == ipv6.ICMPTypeExtendedEchoRequest && rm.Type == ipv6.ICMPTypeExtendedEchoReply:
return nil
default:
return fmt.Errorf("got %+v from %v; want echo reply or extended echo reply", rm, peer)
}
}
func googleAddr(c *icmp.PacketConn, protocol int) (net.Addr, error) {
host := "ipv4.google.com"
if protocol == iana.ProtocolIPv6ICMP {
host = "ipv6.google.com"
}
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}
netaddr := func(ip net.IP) (net.Addr, error) {
switch c.LocalAddr().(type) {
case *net.UDPAddr:
return &net.UDPAddr{IP: ip}, nil
case *net.IPAddr:
return &net.IPAddr{IP: ip}, nil
default:
return nil, errors.New("neither UDPAddr nor IPAddr")
}
}
if len(ips) > 0 {
return netaddr(ips[0])
}
return nil, errors.New("no A or AAAA record")
}
func TestConcurrentNonPrivilegedListenPacket(t *testing.T) {
if testing.Short() {
t.Skip("avoid external network")
}
switch runtime.GOOS {
case "darwin":
case "linux":
t.Log("you may need to adjust the net.ipv4.ping_group_range kernel state")
default:
t.Skipf("not supported on %s", runtime.GOOS)
}
network, address := "udp4", "127.0.0.1"
if !nettest.SupportsIPv4() {
network, address = "udp6", "::1"
}
const N = 1000
var wg sync.WaitGroup
wg.Add(N)
for i := 0; i < N; i++ {
go func() {
defer wg.Done()
c, err := icmp.ListenPacket(network, address)
if err != nil {
t.Error(err)
return
}
c.Close()
}()
}
wg.Wait()
}

View File

@ -16,24 +16,24 @@ func (p *DstUnreach) Len(proto int) int {
if p == nil { if p == nil {
return 0 return 0
} }
l, _ := multipartMessageBodyDataLen(proto, p.Data, p.Extensions) l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
return 4 + l return 4 + l
} }
// Marshal implements the Marshal method of MessageBody interface. // Marshal implements the Marshal method of MessageBody interface.
func (p *DstUnreach) Marshal(proto int) ([]byte, error) { func (p *DstUnreach) Marshal(proto int) ([]byte, error) {
return marshalMultipartMessageBody(proto, p.Data, p.Extensions) return marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
} }
// parseDstUnreach parses b as an ICMP destination unreachable message // parseDstUnreach parses b as an ICMP destination unreachable message
// body. // body.
func parseDstUnreach(proto int, b []byte) (MessageBody, error) { func parseDstUnreach(proto int, typ Type, b []byte) (MessageBody, error) {
if len(b) < 4 { if len(b) < 4 {
return nil, errMessageTooShort return nil, errMessageTooShort
} }
p := &DstUnreach{} p := &DstUnreach{}
var err error var err error
p.Data, p.Extensions, err = parseMultipartMessageBody(proto, b) p.Data, p.Extensions, err = parseMultipartMessageBody(proto, typ, b)
if err != nil { if err != nil {
return nil, err return nil, err
} }

114
vendor/golang.org/x/net/icmp/echo.go generated vendored
View File

@ -31,7 +31,7 @@ func (p *Echo) Marshal(proto int) ([]byte, error) {
} }
// parseEcho parses b as an ICMP echo request or reply message body. // parseEcho parses b as an ICMP echo request or reply message body.
func parseEcho(proto int, b []byte) (MessageBody, error) { func parseEcho(proto int, _ Type, b []byte) (MessageBody, error) {
bodyLen := len(b) bodyLen := len(b)
if bodyLen < 4 { if bodyLen < 4 {
return nil, errMessageTooShort return nil, errMessageTooShort
@ -43,3 +43,115 @@ func parseEcho(proto int, b []byte) (MessageBody, error) {
} }
return p, nil return p, nil
} }
// An ExtendedEchoRequest represents an ICMP extended echo request
// message body.
type ExtendedEchoRequest struct {
ID int // identifier
Seq int // sequence number
Local bool // must be true when identifying by name or index
Extensions []Extension // extensions
}
// Len implements the Len method of MessageBody interface.
func (p *ExtendedEchoRequest) Len(proto int) int {
if p == nil {
return 0
}
l, _ := multipartMessageBodyDataLen(proto, false, nil, p.Extensions)
return 4 + l
}
// Marshal implements the Marshal method of MessageBody interface.
func (p *ExtendedEchoRequest) Marshal(proto int) ([]byte, error) {
b, err := marshalMultipartMessageBody(proto, false, nil, p.Extensions)
if err != nil {
return nil, err
}
bb := make([]byte, 4)
binary.BigEndian.PutUint16(bb[:2], uint16(p.ID))
bb[2] = byte(p.Seq)
if p.Local {
bb[3] |= 0x01
}
bb = append(bb, b...)
return bb, nil
}
// parseExtendedEchoRequest parses b as an ICMP extended echo request
// message body.
func parseExtendedEchoRequest(proto int, typ Type, b []byte) (MessageBody, error) {
if len(b) < 4+4 {
return nil, errMessageTooShort
}
p := &ExtendedEchoRequest{ID: int(binary.BigEndian.Uint16(b[:2])), Seq: int(b[2])}
if b[3]&0x01 != 0 {
p.Local = true
}
var err error
_, p.Extensions, err = parseMultipartMessageBody(proto, typ, b[4:])
if err != nil {
return nil, err
}
return p, nil
}
// An ExtendedEchoReply represents an ICMP extended echo reply message
// body.
type ExtendedEchoReply struct {
ID int // identifier
Seq int // sequence number
State int // 3-bit state working together with Message.Code
Active bool // probed interface is active
IPv4 bool // probed interface runs IPv4
IPv6 bool // probed interface runs IPv6
}
// Len implements the Len method of MessageBody interface.
func (p *ExtendedEchoReply) Len(proto int) int {
if p == nil {
return 0
}
return 4
}
// Marshal implements the Marshal method of MessageBody interface.
func (p *ExtendedEchoReply) Marshal(proto int) ([]byte, error) {
b := make([]byte, 4)
binary.BigEndian.PutUint16(b[:2], uint16(p.ID))
b[2] = byte(p.Seq)
b[3] = byte(p.State<<5) & 0xe0
if p.Active {
b[3] |= 0x04
}
if p.IPv4 {
b[3] |= 0x02
}
if p.IPv6 {
b[3] |= 0x01
}
return b, nil
}
// parseExtendedEchoReply parses b as an ICMP extended echo reply
// message body.
func parseExtendedEchoReply(proto int, _ Type, b []byte) (MessageBody, error) {
if len(b) < 4 {
return nil, errMessageTooShort
}
p := &ExtendedEchoReply{
ID: int(binary.BigEndian.Uint16(b[:2])),
Seq: int(b[2]),
State: int(b[3]) >> 5,
}
if b[3]&0x04 != 0 {
p.Active = true
}
if b[3]&0x02 != 0 {
p.IPv4 = true
}
if b[3]&0x01 != 0 {
p.IPv6 = true
}
return p, nil
}

View File

@ -4,7 +4,12 @@
package icmp package icmp
import "encoding/binary" import (
"encoding/binary"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// An Extension represents an ICMP extension. // An Extension represents an ICMP extension.
type Extension interface { type Extension interface {
@ -38,7 +43,7 @@ func validExtensionHeader(b []byte) bool {
// It will return a list of ICMP extensions and an adjusted length // It will return a list of ICMP extensions and an adjusted length
// attribute that represents the length of the padded original // attribute that represents the length of the padded original
// datagram field. Otherwise, it returns an error. // datagram field. Otherwise, it returns an error.
func parseExtensions(b []byte, l int) ([]Extension, int, error) { func parseExtensions(typ Type, b []byte, l int) ([]Extension, int, error) {
// Still a lot of non-RFC 4884 compliant implementations are // Still a lot of non-RFC 4884 compliant implementations are
// out there. Set the length attribute l to 128 when it looks // out there. Set the length attribute l to 128 when it looks
// inappropriate for backwards compatibility. // inappropriate for backwards compatibility.
@ -48,20 +53,28 @@ func parseExtensions(b []byte, l int) ([]Extension, int, error) {
// header. // header.
// //
// See RFC 4884 for further information. // See RFC 4884 for further information.
if 128 > l || l+8 > len(b) { switch typ {
l = 128 case ipv4.ICMPTypeExtendedEchoRequest, ipv6.ICMPTypeExtendedEchoRequest:
} if len(b) < 8 || !validExtensionHeader(b) {
if l+8 > len(b) {
return nil, -1, errNoExtension
}
if !validExtensionHeader(b[l:]) {
if l == 128 {
return nil, -1, errNoExtension return nil, -1, errNoExtension
} }
l = 128 l = 0
if !validExtensionHeader(b[l:]) { default:
if 128 > l || l+8 > len(b) {
l = 128
}
if l+8 > len(b) {
return nil, -1, errNoExtension return nil, -1, errNoExtension
} }
if !validExtensionHeader(b[l:]) {
if l == 128 {
return nil, -1, errNoExtension
}
l = 128
if !validExtensionHeader(b[l:]) {
return nil, -1, errNoExtension
}
}
} }
var exts []Extension var exts []Extension
for b = b[l+4:]; len(b) >= 4; { for b = b[l+4:]; len(b) >= 4; {
@ -82,6 +95,12 @@ func parseExtensions(b []byte, l int) ([]Extension, int, error) {
return nil, -1, err return nil, -1, err
} }
exts = append(exts, ext) exts = append(exts, ext)
case classInterfaceIdent:
ext, err := parseInterfaceIdent(b[:ol])
if err != nil {
return nil, -1, err
}
exts = append(exts, ext)
} }
b = b[ol:] b = b[ol:]
} }

View File

@ -5,253 +5,327 @@
package icmp package icmp
import ( import (
"fmt"
"net" "net"
"reflect" "reflect"
"testing" "testing"
"golang.org/x/net/internal/iana" "golang.org/x/net/internal/iana"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
) )
var marshalAndParseExtensionTests = []struct {
proto int
hdr []byte
obj []byte
exts []Extension
}{
// MPLS label stack with no label
{
proto: iana.ProtocolICMP,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x04, 0x01, 0x01,
},
exts: []Extension{
&MPLSLabelStack{
Class: classMPLSLabelStack,
Type: typeIncomingMPLSLabelStack,
},
},
},
// MPLS label stack with a single label
{
proto: iana.ProtocolIPv6ICMP,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x08, 0x01, 0x01,
0x03, 0xe8, 0xe9, 0xff,
},
exts: []Extension{
&MPLSLabelStack{
Class: classMPLSLabelStack,
Type: typeIncomingMPLSLabelStack,
Labels: []MPLSLabel{
{
Label: 16014,
TC: 0x4,
S: true,
TTL: 255,
},
},
},
},
},
// MPLS label stack with multiple labels
{
proto: iana.ProtocolICMP,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x0c, 0x01, 0x01,
0x03, 0xe8, 0xde, 0xfe,
0x03, 0xe8, 0xe1, 0xff,
},
exts: []Extension{
&MPLSLabelStack{
Class: classMPLSLabelStack,
Type: typeIncomingMPLSLabelStack,
Labels: []MPLSLabel{
{
Label: 16013,
TC: 0x7,
S: false,
TTL: 254,
},
{
Label: 16014,
TC: 0,
S: true,
TTL: 255,
},
},
},
},
},
// Interface information with no attribute
{
proto: iana.ProtocolICMP,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x04, 0x02, 0x00,
},
exts: []Extension{
&InterfaceInfo{
Class: classInterfaceInfo,
},
},
},
// Interface information with ifIndex and name
{
proto: iana.ProtocolICMP,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x10, 0x02, 0x0a,
0x00, 0x00, 0x00, 0x10,
0x08, byte('e'), byte('n'), byte('1'),
byte('0'), byte('1'), 0x00, 0x00,
},
exts: []Extension{
&InterfaceInfo{
Class: classInterfaceInfo,
Type: 0x0a,
Interface: &net.Interface{
Index: 16,
Name: "en101",
},
},
},
},
// Interface information with ifIndex, IPAddr, name and MTU
{
proto: iana.ProtocolIPv6ICMP,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x28, 0x02, 0x0f,
0x00, 0x00, 0x00, 0x0f,
0x00, 0x02, 0x00, 0x00,
0xfe, 0x80, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01,
0x08, byte('e'), byte('n'), byte('1'),
byte('0'), byte('1'), 0x00, 0x00,
0x00, 0x00, 0x20, 0x00,
},
exts: []Extension{
&InterfaceInfo{
Class: classInterfaceInfo,
Type: 0x0f,
Interface: &net.Interface{
Index: 15,
Name: "en101",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.ParseIP("fe80::1"),
Zone: "en101",
},
},
},
},
}
func TestMarshalAndParseExtension(t *testing.T) { func TestMarshalAndParseExtension(t *testing.T) {
for i, tt := range marshalAndParseExtensionTests { fn := func(t *testing.T, proto int, typ Type, hdr, obj []byte, te Extension) error {
for j, ext := range tt.exts { b, err := te.Marshal(proto)
var err error if err != nil {
var b []byte return err
switch ext := ext.(type) { }
case *MPLSLabelStack: if !reflect.DeepEqual(b, obj) {
b, err = ext.Marshal(tt.proto) return fmt.Errorf("got %#v; want %#v", b, obj)
if err != nil { }
t.Errorf("#%v/%v: %v", i, j, err) switch typ {
case ipv4.ICMPTypeExtendedEchoRequest, ipv6.ICMPTypeExtendedEchoRequest:
exts, l, err := parseExtensions(typ, append(hdr, obj...), 0)
if err != nil {
return err
}
if l != 0 {
return fmt.Errorf("got %d; want 0", l)
}
if !reflect.DeepEqual(exts, []Extension{te}) {
return fmt.Errorf("got %#v; want %#v", exts[0], te)
}
default:
for i, wire := range []struct {
data []byte // original datagram
inlattr int // length of padded original datagram, a hint
outlattr int // length of padded original datagram, a want
err error
}{
{nil, 0, -1, errNoExtension},
{make([]byte, 127), 128, -1, errNoExtension},
{make([]byte, 128), 127, -1, errNoExtension},
{make([]byte, 128), 128, -1, errNoExtension},
{make([]byte, 128), 129, -1, errNoExtension},
{append(make([]byte, 128), append(hdr, obj...)...), 127, 128, nil},
{append(make([]byte, 128), append(hdr, obj...)...), 128, 128, nil},
{append(make([]byte, 128), append(hdr, obj...)...), 129, 128, nil},
{append(make([]byte, 512), append(hdr, obj...)...), 511, -1, errNoExtension},
{append(make([]byte, 512), append(hdr, obj...)...), 512, 512, nil},
{append(make([]byte, 512), append(hdr, obj...)...), 513, -1, errNoExtension},
} {
exts, l, err := parseExtensions(typ, wire.data, wire.inlattr)
if err != wire.err {
return fmt.Errorf("#%d: got %v; want %v", i, err, wire.err)
}
if wire.err != nil {
continue continue
} }
case *InterfaceInfo: if l != wire.outlattr {
b, err = ext.Marshal(tt.proto) return fmt.Errorf("#%d: got %d; want %d", i, l, wire.outlattr)
if err != nil { }
t.Errorf("#%v/%v: %v", i, j, err) if !reflect.DeepEqual(exts, []Extension{te}) {
continue return fmt.Errorf("#%d: got %#v; want %#v", i, exts[0], te)
} }
} }
if !reflect.DeepEqual(b, tt.obj) {
t.Errorf("#%v/%v: got %#v; want %#v", i, j, b, tt.obj)
continue
}
}
for j, wire := range []struct {
data []byte // original datagram
inlattr int // length of padded original datagram, a hint
outlattr int // length of padded original datagram, a want
err error
}{
{nil, 0, -1, errNoExtension},
{make([]byte, 127), 128, -1, errNoExtension},
{make([]byte, 128), 127, -1, errNoExtension},
{make([]byte, 128), 128, -1, errNoExtension},
{make([]byte, 128), 129, -1, errNoExtension},
{append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 127, 128, nil},
{append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 128, 128, nil},
{append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 129, 128, nil},
{append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 511, -1, errNoExtension},
{append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 512, 512, nil},
{append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 513, -1, errNoExtension},
} {
exts, l, err := parseExtensions(wire.data, wire.inlattr)
if err != wire.err {
t.Errorf("#%v/%v: got %v; want %v", i, j, err, wire.err)
continue
}
if wire.err != nil {
continue
}
if l != wire.outlattr {
t.Errorf("#%v/%v: got %v; want %v", i, j, l, wire.outlattr)
}
if !reflect.DeepEqual(exts, tt.exts) {
for j, ext := range exts {
switch ext := ext.(type) {
case *MPLSLabelStack:
want := tt.exts[j].(*MPLSLabelStack)
t.Errorf("#%v/%v: got %#v; want %#v", i, j, ext, want)
case *InterfaceInfo:
want := tt.exts[j].(*InterfaceInfo)
t.Errorf("#%v/%v: got %#v; want %#v", i, j, ext, want)
}
}
continue
}
} }
return nil
} }
}
var parseInterfaceNameTests = []struct { t.Run("MPLSLabelStack", func(t *testing.T) {
b []byte for _, et := range []struct {
error proto int
}{ typ Type
{[]byte{0, 'e', 'n', '0'}, errInvalidExtension}, hdr []byte
{[]byte{4, 'e', 'n', '0'}, nil}, obj []byte
{[]byte{7, 'e', 'n', '0', 0xff, 0xff, 0xff, 0xff}, errInvalidExtension}, ext Extension
{[]byte{8, 'e', 'n', '0', 0xff, 0xff, 0xff}, errMessageTooShort}, }{
// MPLS label stack with no label
{
proto: iana.ProtocolICMP,
typ: ipv4.ICMPTypeDestinationUnreachable,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x04, 0x01, 0x01,
},
ext: &MPLSLabelStack{
Class: classMPLSLabelStack,
Type: typeIncomingMPLSLabelStack,
},
},
// MPLS label stack with a single label
{
proto: iana.ProtocolIPv6ICMP,
typ: ipv6.ICMPTypeDestinationUnreachable,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x08, 0x01, 0x01,
0x03, 0xe8, 0xe9, 0xff,
},
ext: &MPLSLabelStack{
Class: classMPLSLabelStack,
Type: typeIncomingMPLSLabelStack,
Labels: []MPLSLabel{
{
Label: 16014,
TC: 0x4,
S: true,
TTL: 255,
},
},
},
},
// MPLS label stack with multiple labels
{
proto: iana.ProtocolICMP,
typ: ipv4.ICMPTypeDestinationUnreachable,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x0c, 0x01, 0x01,
0x03, 0xe8, 0xde, 0xfe,
0x03, 0xe8, 0xe1, 0xff,
},
ext: &MPLSLabelStack{
Class: classMPLSLabelStack,
Type: typeIncomingMPLSLabelStack,
Labels: []MPLSLabel{
{
Label: 16013,
TC: 0x7,
S: false,
TTL: 254,
},
{
Label: 16014,
TC: 0,
S: true,
TTL: 255,
},
},
},
},
} {
if err := fn(t, et.proto, et.typ, et.hdr, et.obj, et.ext); err != nil {
t.Error(err)
}
}
})
t.Run("InterfaceInfo", func(t *testing.T) {
for _, et := range []struct {
proto int
typ Type
hdr []byte
obj []byte
ext Extension
}{
// Interface information with no attribute
{
proto: iana.ProtocolICMP,
typ: ipv4.ICMPTypeDestinationUnreachable,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x04, 0x02, 0x00,
},
ext: &InterfaceInfo{
Class: classInterfaceInfo,
},
},
// Interface information with ifIndex and name
{
proto: iana.ProtocolICMP,
typ: ipv4.ICMPTypeDestinationUnreachable,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x10, 0x02, 0x0a,
0x00, 0x00, 0x00, 0x10,
0x08, byte('e'), byte('n'), byte('1'),
byte('0'), byte('1'), 0x00, 0x00,
},
ext: &InterfaceInfo{
Class: classInterfaceInfo,
Type: 0x0a,
Interface: &net.Interface{
Index: 16,
Name: "en101",
},
},
},
// Interface information with ifIndex, IPAddr, name and MTU
{
proto: iana.ProtocolIPv6ICMP,
typ: ipv6.ICMPTypeDestinationUnreachable,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x28, 0x02, 0x0f,
0x00, 0x00, 0x00, 0x0f,
0x00, 0x02, 0x00, 0x00,
0xfe, 0x80, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01,
0x08, byte('e'), byte('n'), byte('1'),
byte('0'), byte('1'), 0x00, 0x00,
0x00, 0x00, 0x20, 0x00,
},
ext: &InterfaceInfo{
Class: classInterfaceInfo,
Type: 0x0f,
Interface: &net.Interface{
Index: 15,
Name: "en101",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.ParseIP("fe80::1"),
Zone: "en101",
},
},
},
} {
if err := fn(t, et.proto, et.typ, et.hdr, et.obj, et.ext); err != nil {
t.Error(err)
}
}
})
t.Run("InterfaceIdent", func(t *testing.T) {
for _, et := range []struct {
proto int
typ Type
hdr []byte
obj []byte
ext Extension
}{
// Interface identification by name
{
proto: iana.ProtocolICMP,
typ: ipv4.ICMPTypeExtendedEchoRequest,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x0c, 0x03, 0x01,
byte('e'), byte('n'), byte('1'), byte('0'),
byte('1'), 0x00, 0x00, 0x00,
},
ext: &InterfaceIdent{
Class: classInterfaceIdent,
Type: typeInterfaceByName,
Name: "en101",
},
},
// Interface identification by index
{
proto: iana.ProtocolIPv6ICMP,
typ: ipv6.ICMPTypeExtendedEchoRequest,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x0c, 0x03, 0x02,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x03, 0x8f,
},
ext: &InterfaceIdent{
Class: classInterfaceIdent,
Type: typeInterfaceByIndex,
Index: 911,
},
},
// Interface identification by address
{
proto: iana.ProtocolICMP,
typ: ipv4.ICMPTypeExtendedEchoRequest,
hdr: []byte{
0x20, 0x00, 0x00, 0x00,
},
obj: []byte{
0x00, 0x10, 0x03, 0x03,
byte(iana.AddrFamily48bitMAC >> 8), byte(iana.AddrFamily48bitMAC & 0x0f), 0x06, 0x00,
0x01, 0x23, 0x45, 0x67,
0x89, 0xab, 0x00, 0x00,
},
ext: &InterfaceIdent{
Class: classInterfaceIdent,
Type: typeInterfaceByAddress,
AFI: iana.AddrFamily48bitMAC,
Addr: []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab},
},
},
} {
if err := fn(t, et.proto, et.typ, et.hdr, et.obj, et.ext); err != nil {
t.Error(err)
}
}
})
} }
func TestParseInterfaceName(t *testing.T) { func TestParseInterfaceName(t *testing.T) {
ifi := InterfaceInfo{Interface: &net.Interface{}} ifi := InterfaceInfo{Interface: &net.Interface{}}
for i, tt := range parseInterfaceNameTests { for i, tt := range []struct {
b []byte
error
}{
{[]byte{0, 'e', 'n', '0'}, errInvalidExtension},
{[]byte{4, 'e', 'n', '0'}, nil},
{[]byte{7, 'e', 'n', '0', 0xff, 0xff, 0xff, 0xff}, errInvalidExtension},
{[]byte{8, 'e', 'n', '0', 0xff, 0xff, 0xff}, errMessageTooShort},
} {
if _, err := ifi.parseName(tt.b); err != tt.error { if _, err := ifi.parseName(tt.b); err != tt.error {
t.Errorf("#%d: got %v; want %v", i, err, tt.error) t.Errorf("#%d: got %v; want %v", i, err, tt.error)
} }

View File

@ -1,27 +0,0 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package icmp
import (
"encoding/binary"
"unsafe"
)
var (
// See http://www.freebsd.org/doc/en/books/porters-handbook/freebsd-versions.html.
freebsdVersion uint32
nativeEndian binary.ByteOrder
)
func init() {
i := uint32(1)
b := (*[4]byte)(unsafe.Pointer(&i))
if b[0] == 1 {
nativeEndian = binary.LittleEndian
} else {
nativeEndian = binary.BigEndian
}
}

View File

@ -14,9 +14,6 @@ import (
const ( const (
classInterfaceInfo = 2 classInterfaceInfo = 2
afiIPv4 = 1
afiIPv6 = 2
) )
const ( const (
@ -127,11 +124,11 @@ func (ifi *InterfaceInfo) parseIfIndex(b []byte) ([]byte, error) {
func (ifi *InterfaceInfo) marshalIPAddr(proto int, b []byte) []byte { func (ifi *InterfaceInfo) marshalIPAddr(proto int, b []byte) []byte {
switch proto { switch proto {
case iana.ProtocolICMP: case iana.ProtocolICMP:
binary.BigEndian.PutUint16(b[:2], uint16(afiIPv4)) binary.BigEndian.PutUint16(b[:2], uint16(iana.AddrFamilyIPv4))
copy(b[4:4+net.IPv4len], ifi.Addr.IP.To4()) copy(b[4:4+net.IPv4len], ifi.Addr.IP.To4())
b = b[4+net.IPv4len:] b = b[4+net.IPv4len:]
case iana.ProtocolIPv6ICMP: case iana.ProtocolIPv6ICMP:
binary.BigEndian.PutUint16(b[:2], uint16(afiIPv6)) binary.BigEndian.PutUint16(b[:2], uint16(iana.AddrFamilyIPv6))
copy(b[4:4+net.IPv6len], ifi.Addr.IP.To16()) copy(b[4:4+net.IPv6len], ifi.Addr.IP.To16())
b = b[4+net.IPv6len:] b = b[4+net.IPv6len:]
} }
@ -145,14 +142,14 @@ func (ifi *InterfaceInfo) parseIPAddr(b []byte) ([]byte, error) {
afi := int(binary.BigEndian.Uint16(b[:2])) afi := int(binary.BigEndian.Uint16(b[:2]))
b = b[4:] b = b[4:]
switch afi { switch afi {
case afiIPv4: case iana.AddrFamilyIPv4:
if len(b) < net.IPv4len { if len(b) < net.IPv4len {
return nil, errMessageTooShort return nil, errMessageTooShort
} }
ifi.Addr.IP = make(net.IP, net.IPv4len) ifi.Addr.IP = make(net.IP, net.IPv4len)
copy(ifi.Addr.IP, b[:net.IPv4len]) copy(ifi.Addr.IP, b[:net.IPv4len])
b = b[net.IPv4len:] b = b[net.IPv4len:]
case afiIPv6: case iana.AddrFamilyIPv6:
if len(b) < net.IPv6len { if len(b) < net.IPv6len {
return nil, errMessageTooShort return nil, errMessageTooShort
} }
@ -234,3 +231,92 @@ func parseInterfaceInfo(b []byte) (Extension, error) {
} }
return ifi, nil return ifi, nil
} }
const (
classInterfaceIdent = 3
typeInterfaceByName = 1
typeInterfaceByIndex = 2
typeInterfaceByAddress = 3
)
// An InterfaceIdent represents interface identification.
type InterfaceIdent struct {
Class int // extension object class number
Type int // extension object sub-type
Name string // interface name
Index int // interface index
AFI int // address family identifier; see address family numbers in IANA registry
Addr []byte // address
}
// Len implements the Len method of Extension interface.
func (ifi *InterfaceIdent) Len(_ int) int {
switch ifi.Type {
case typeInterfaceByName:
l := len(ifi.Name)
if l > 255 {
l = 255
}
return 4 + (l+3)&^3
case typeInterfaceByIndex:
return 4 + 8
case typeInterfaceByAddress:
return 4 + 4 + (len(ifi.Addr)+3)&^3
default:
return 4
}
}
// Marshal implements the Marshal method of Extension interface.
func (ifi *InterfaceIdent) Marshal(proto int) ([]byte, error) {
b := make([]byte, ifi.Len(proto))
if err := ifi.marshal(proto, b); err != nil {
return nil, err
}
return b, nil
}
func (ifi *InterfaceIdent) marshal(proto int, b []byte) error {
l := ifi.Len(proto)
binary.BigEndian.PutUint16(b[:2], uint16(l))
b[2], b[3] = classInterfaceIdent, byte(ifi.Type)
switch ifi.Type {
case typeInterfaceByName:
copy(b[4:], ifi.Name)
case typeInterfaceByIndex:
binary.BigEndian.PutUint64(b[4:4+8], uint64(ifi.Index))
case typeInterfaceByAddress:
binary.BigEndian.PutUint16(b[4:4+2], uint16(ifi.AFI))
b[4+2] = byte(len(ifi.Addr))
copy(b[4+4:], ifi.Addr)
}
return nil
}
func parseInterfaceIdent(b []byte) (Extension, error) {
ifi := &InterfaceIdent{
Class: int(b[2]),
Type: int(b[3]),
}
switch ifi.Type {
case typeInterfaceByName:
ifi.Name = strings.Trim(string(b[4:]), string(0))
case typeInterfaceByIndex:
if len(b[4:]) < 8 {
return nil, errInvalidExtension
}
ifi.Index = int(binary.BigEndian.Uint64(b[4 : 4+8]))
case typeInterfaceByAddress:
if len(b[4:]) < 4 {
return nil, errInvalidExtension
}
ifi.AFI = int(binary.BigEndian.Uint16(b[4 : 4+2]))
l := int(b[4+2])
if len(b[4+4:]) < l {
return nil, errInvalidExtension
}
ifi.Addr = make([]byte, l)
copy(ifi.Addr, b[4+4:])
}
return ifi, nil
}

View File

@ -9,9 +9,14 @@ import (
"net" "net"
"runtime" "runtime"
"golang.org/x/net/internal/socket"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
// freebsdVersion is set in sys_freebsd.go.
// See http://www.freebsd.org/doc/en/books/porters-handbook/freebsd-versions.html.
var freebsdVersion uint32
// ParseIPv4Header parses b as an IPv4 header of ICMP error message // ParseIPv4Header parses b as an IPv4 header of ICMP error message
// invoking packet, which is contained in ICMP error message. // invoking packet, which is contained in ICMP error message.
func ParseIPv4Header(b []byte) (*ipv4.Header, error) { func ParseIPv4Header(b []byte) (*ipv4.Header, error) {
@ -36,12 +41,12 @@ func ParseIPv4Header(b []byte) (*ipv4.Header, error) {
} }
switch runtime.GOOS { switch runtime.GOOS {
case "darwin": case "darwin":
h.TotalLen = int(nativeEndian.Uint16(b[2:4])) h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4]))
case "freebsd": case "freebsd":
if freebsdVersion >= 1000000 { if freebsdVersion >= 1000000 {
h.TotalLen = int(binary.BigEndian.Uint16(b[2:4])) h.TotalLen = int(binary.BigEndian.Uint16(b[2:4]))
} else { } else {
h.TotalLen = int(nativeEndian.Uint16(b[2:4])) h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4]))
} }
default: default:
h.TotalLen = int(binary.BigEndian.Uint16(b[2:4])) h.TotalLen = int(binary.BigEndian.Uint16(b[2:4]))

View File

@ -11,72 +11,65 @@ import (
"runtime" "runtime"
"testing" "testing"
"golang.org/x/net/internal/socket"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
type ipv4HeaderTest struct {
wireHeaderFromKernel [ipv4.HeaderLen]byte
wireHeaderFromTradBSDKernel [ipv4.HeaderLen]byte
Header *ipv4.Header
}
var ipv4HeaderLittleEndianTest = ipv4HeaderTest{
// TODO(mikio): Add platform dependent wire header formats when
// we support new platforms.
wireHeaderFromKernel: [ipv4.HeaderLen]byte{
0x45, 0x01, 0xbe, 0xef,
0xca, 0xfe, 0x45, 0xdc,
0xff, 0x01, 0xde, 0xad,
172, 16, 254, 254,
192, 168, 0, 1,
},
wireHeaderFromTradBSDKernel: [ipv4.HeaderLen]byte{
0x45, 0x01, 0xef, 0xbe,
0xca, 0xfe, 0x45, 0xdc,
0xff, 0x01, 0xde, 0xad,
172, 16, 254, 254,
192, 168, 0, 1,
},
Header: &ipv4.Header{
Version: ipv4.Version,
Len: ipv4.HeaderLen,
TOS: 1,
TotalLen: 0xbeef,
ID: 0xcafe,
Flags: ipv4.DontFragment,
FragOff: 1500,
TTL: 255,
Protocol: 1,
Checksum: 0xdead,
Src: net.IPv4(172, 16, 254, 254),
Dst: net.IPv4(192, 168, 0, 1),
},
}
func TestParseIPv4Header(t *testing.T) { func TestParseIPv4Header(t *testing.T) {
tt := &ipv4HeaderLittleEndianTest switch socket.NativeEndian {
if nativeEndian != binary.LittleEndian { case binary.LittleEndian:
t.Skip("no test for non-little endian machine yet") t.Run("LittleEndian", func(t *testing.T) {
} // TODO(mikio): Add platform dependent wire
// header formats when we support new
var wh []byte // platforms.
switch runtime.GOOS { wireHeaderFromKernel := [ipv4.HeaderLen]byte{
case "darwin": 0x45, 0x01, 0xbe, 0xef,
wh = tt.wireHeaderFromTradBSDKernel[:] 0xca, 0xfe, 0x45, 0xdc,
case "freebsd": 0xff, 0x01, 0xde, 0xad,
if freebsdVersion >= 1000000 { 172, 16, 254, 254,
wh = tt.wireHeaderFromKernel[:] 192, 168, 0, 1,
} else { }
wh = tt.wireHeaderFromTradBSDKernel[:] wireHeaderFromTradBSDKernel := [ipv4.HeaderLen]byte{
} 0x45, 0x01, 0xef, 0xbe,
default: 0xca, 0xfe, 0x45, 0xdc,
wh = tt.wireHeaderFromKernel[:] 0xff, 0x01, 0xde, 0xad,
} 172, 16, 254, 254,
h, err := ParseIPv4Header(wh) 192, 168, 0, 1,
if err != nil { }
t.Fatal(err) th := &ipv4.Header{
} Version: ipv4.Version,
if !reflect.DeepEqual(h, tt.Header) { Len: ipv4.HeaderLen,
t.Fatalf("got %#v; want %#v", h, tt.Header) TOS: 1,
TotalLen: 0xbeef,
ID: 0xcafe,
Flags: ipv4.DontFragment,
FragOff: 1500,
TTL: 255,
Protocol: 1,
Checksum: 0xdead,
Src: net.IPv4(172, 16, 254, 254),
Dst: net.IPv4(192, 168, 0, 1),
}
var wh []byte
switch runtime.GOOS {
case "darwin":
wh = wireHeaderFromTradBSDKernel[:]
case "freebsd":
if freebsdVersion >= 1000000 {
wh = wireHeaderFromKernel[:]
} else {
wh = wireHeaderFromTradBSDKernel[:]
}
default:
wh = wireHeaderFromKernel[:]
}
h, err := ParseIPv4Header(wh)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(h, th) {
t.Fatalf("got %#v; want %#v", h, th)
}
})
} }
} }

View File

@ -11,6 +11,7 @@
// ICMP extensions for MPLS are defined in RFC 4950. // ICMP extensions for MPLS are defined in RFC 4950.
// ICMP extensions for interface and next-hop identification are // ICMP extensions for interface and next-hop identification are
// defined in RFC 5837. // defined in RFC 5837.
// PROBE: A utility for probing interfaces is defined in RFC 8335.
package icmp // import "golang.org/x/net/icmp" package icmp // import "golang.org/x/net/icmp"
import ( import (
@ -107,21 +108,25 @@ func (m *Message) Marshal(psh []byte) ([]byte, error) {
return b[len(psh):], nil return b[len(psh):], nil
} }
var parseFns = map[Type]func(int, []byte) (MessageBody, error){ var parseFns = map[Type]func(int, Type, []byte) (MessageBody, error){
ipv4.ICMPTypeDestinationUnreachable: parseDstUnreach, ipv4.ICMPTypeDestinationUnreachable: parseDstUnreach,
ipv4.ICMPTypeTimeExceeded: parseTimeExceeded, ipv4.ICMPTypeTimeExceeded: parseTimeExceeded,
ipv4.ICMPTypeParameterProblem: parseParamProb, ipv4.ICMPTypeParameterProblem: parseParamProb,
ipv4.ICMPTypeEcho: parseEcho, ipv4.ICMPTypeEcho: parseEcho,
ipv4.ICMPTypeEchoReply: parseEcho, ipv4.ICMPTypeEchoReply: parseEcho,
ipv4.ICMPTypeExtendedEchoRequest: parseExtendedEchoRequest,
ipv4.ICMPTypeExtendedEchoReply: parseExtendedEchoReply,
ipv6.ICMPTypeDestinationUnreachable: parseDstUnreach, ipv6.ICMPTypeDestinationUnreachable: parseDstUnreach,
ipv6.ICMPTypePacketTooBig: parsePacketTooBig, ipv6.ICMPTypePacketTooBig: parsePacketTooBig,
ipv6.ICMPTypeTimeExceeded: parseTimeExceeded, ipv6.ICMPTypeTimeExceeded: parseTimeExceeded,
ipv6.ICMPTypeParameterProblem: parseParamProb, ipv6.ICMPTypeParameterProblem: parseParamProb,
ipv6.ICMPTypeEchoRequest: parseEcho, ipv6.ICMPTypeEchoRequest: parseEcho,
ipv6.ICMPTypeEchoReply: parseEcho, ipv6.ICMPTypeEchoReply: parseEcho,
ipv6.ICMPTypeExtendedEchoRequest: parseExtendedEchoRequest,
ipv6.ICMPTypeExtendedEchoReply: parseExtendedEchoReply,
} }
// ParseMessage parses b as an ICMP message. // ParseMessage parses b as an ICMP message.
@ -143,7 +148,7 @@ func ParseMessage(proto int, b []byte) (*Message, error) {
if fn, ok := parseFns[m.Type]; !ok { if fn, ok := parseFns[m.Type]; !ok {
m.Body, err = parseDefaultMessageBody(proto, b[4:]) m.Body, err = parseDefaultMessageBody(proto, b[4:])
} else { } else {
m.Body, err = fn(proto, b[4:]) m.Body, err = fn(proto, m.Type, b[4:])
} }
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -15,120 +15,141 @@ import (
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
) )
var marshalAndParseMessageForIPv4Tests = []icmp.Message{ func TestMarshalAndParseMessage(t *testing.T) {
{ fn := func(t *testing.T, proto int, tms []icmp.Message) {
Type: ipv4.ICMPTypeDestinationUnreachable, Code: 15, var pshs [][]byte
Body: &icmp.DstUnreach{ switch proto {
Data: []byte("ERROR-INVOKING-PACKET"), case iana.ProtocolICMP:
}, pshs = [][]byte{nil}
}, case iana.ProtocolIPv6ICMP:
{ pshs = [][]byte{
Type: ipv4.ICMPTypeTimeExceeded, Code: 1, icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1")),
Body: &icmp.TimeExceeded{ nil,
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv4.ICMPTypeParameterProblem, Code: 2,
Body: &icmp.ParamProb{
Pointer: 8,
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv4.ICMPTypeEcho, Code: 0,
Body: &icmp.Echo{
ID: 1, Seq: 2,
Data: []byte("HELLO-R-U-THERE"),
},
},
{
Type: ipv4.ICMPTypePhoturis,
Body: &icmp.DefaultMessageBody{
Data: []byte{0x80, 0x40, 0x20, 0x10},
},
},
}
func TestMarshalAndParseMessageForIPv4(t *testing.T) {
for i, tt := range marshalAndParseMessageForIPv4Tests {
b, err := tt.Marshal(nil)
if err != nil {
t.Fatal(err)
}
m, err := icmp.ParseMessage(iana.ProtocolICMP, b)
if err != nil {
t.Fatal(err)
}
if m.Type != tt.Type || m.Code != tt.Code {
t.Errorf("#%v: got %v; want %v", i, m, &tt)
}
if !reflect.DeepEqual(m.Body, tt.Body) {
t.Errorf("#%v: got %v; want %v", i, m.Body, tt.Body)
}
}
}
var marshalAndParseMessageForIPv6Tests = []icmp.Message{
{
Type: ipv6.ICMPTypeDestinationUnreachable, Code: 6,
Body: &icmp.DstUnreach{
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv6.ICMPTypePacketTooBig, Code: 0,
Body: &icmp.PacketTooBig{
MTU: 1<<16 - 1,
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv6.ICMPTypeTimeExceeded, Code: 1,
Body: &icmp.TimeExceeded{
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv6.ICMPTypeParameterProblem, Code: 2,
Body: &icmp.ParamProb{
Pointer: 8,
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv6.ICMPTypeEchoRequest, Code: 0,
Body: &icmp.Echo{
ID: 1, Seq: 2,
Data: []byte("HELLO-R-U-THERE"),
},
},
{
Type: ipv6.ICMPTypeDuplicateAddressConfirmation,
Body: &icmp.DefaultMessageBody{
Data: []byte{0x80, 0x40, 0x20, 0x10},
},
},
}
func TestMarshalAndParseMessageForIPv6(t *testing.T) {
pshicmp := icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1"))
for i, tt := range marshalAndParseMessageForIPv6Tests {
for _, psh := range [][]byte{pshicmp, nil} {
b, err := tt.Marshal(psh)
if err != nil {
t.Fatal(err)
} }
m, err := icmp.ParseMessage(iana.ProtocolIPv6ICMP, b) }
if err != nil { for i, tm := range tms {
t.Fatal(err) for _, psh := range pshs {
} b, err := tm.Marshal(psh)
if m.Type != tt.Type || m.Code != tt.Code { if err != nil {
t.Errorf("#%v: got %v; want %v", i, m, &tt) t.Fatal(err)
} }
if !reflect.DeepEqual(m.Body, tt.Body) { m, err := icmp.ParseMessage(proto, b)
t.Errorf("#%v: got %v; want %v", i, m.Body, tt.Body) if err != nil {
t.Fatal(err)
}
if m.Type != tm.Type || m.Code != tm.Code {
t.Errorf("#%d: got %#v; want %#v", i, m, &tm)
}
if !reflect.DeepEqual(m.Body, tm.Body) {
t.Errorf("#%d: got %#v; want %#v", i, m.Body, tm.Body)
}
} }
} }
} }
t.Run("IPv4", func(t *testing.T) {
fn(t, iana.ProtocolICMP,
[]icmp.Message{
{
Type: ipv4.ICMPTypeDestinationUnreachable, Code: 15,
Body: &icmp.DstUnreach{
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv4.ICMPTypeTimeExceeded, Code: 1,
Body: &icmp.TimeExceeded{
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv4.ICMPTypeParameterProblem, Code: 2,
Body: &icmp.ParamProb{
Pointer: 8,
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv4.ICMPTypeEcho, Code: 0,
Body: &icmp.Echo{
ID: 1, Seq: 2,
Data: []byte("HELLO-R-U-THERE"),
},
},
{
Type: ipv4.ICMPTypeExtendedEchoRequest, Code: 0,
Body: &icmp.ExtendedEchoRequest{
ID: 1, Seq: 2,
},
},
{
Type: ipv4.ICMPTypeExtendedEchoReply, Code: 0,
Body: &icmp.ExtendedEchoReply{
State: 4 /* Delay */, Active: true, IPv4: true,
},
},
{
Type: ipv4.ICMPTypePhoturis,
Body: &icmp.DefaultMessageBody{
Data: []byte{0x80, 0x40, 0x20, 0x10},
},
},
})
})
t.Run("IPv6", func(t *testing.T) {
fn(t, iana.ProtocolIPv6ICMP,
[]icmp.Message{
{
Type: ipv6.ICMPTypeDestinationUnreachable, Code: 6,
Body: &icmp.DstUnreach{
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv6.ICMPTypePacketTooBig, Code: 0,
Body: &icmp.PacketTooBig{
MTU: 1<<16 - 1,
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv6.ICMPTypeTimeExceeded, Code: 1,
Body: &icmp.TimeExceeded{
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv6.ICMPTypeParameterProblem, Code: 2,
Body: &icmp.ParamProb{
Pointer: 8,
Data: []byte("ERROR-INVOKING-PACKET"),
},
},
{
Type: ipv6.ICMPTypeEchoRequest, Code: 0,
Body: &icmp.Echo{
ID: 1, Seq: 2,
Data: []byte("HELLO-R-U-THERE"),
},
},
{
Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 0,
Body: &icmp.ExtendedEchoRequest{
ID: 1, Seq: 2,
},
},
{
Type: ipv6.ICMPTypeExtendedEchoReply, Code: 0,
Body: &icmp.ExtendedEchoReply{
State: 5 /* Probe */, Active: true, IPv6: true,
},
},
{
Type: ipv6.ICMPTypeDuplicateAddressConfirmation,
Body: &icmp.DefaultMessageBody{
Data: []byte{0x80, 0x40, 0x20, 0x10},
},
},
})
})
} }

View File

@ -10,12 +10,14 @@ import "golang.org/x/net/internal/iana"
// exts as extensions, and returns a required length for message body // exts as extensions, and returns a required length for message body
// and a required length for a padded original datagram in wire // and a required length for a padded original datagram in wire
// format. // format.
func multipartMessageBodyDataLen(proto int, b []byte, exts []Extension) (bodyLen, dataLen int) { func multipartMessageBodyDataLen(proto int, withOrigDgram bool, b []byte, exts []Extension) (bodyLen, dataLen int) {
for _, ext := range exts { for _, ext := range exts {
bodyLen += ext.Len(proto) bodyLen += ext.Len(proto)
} }
if bodyLen > 0 { if bodyLen > 0 {
dataLen = multipartMessageOrigDatagramLen(proto, b) if withOrigDgram {
dataLen = multipartMessageOrigDatagramLen(proto, b)
}
bodyLen += 4 // length of extension header bodyLen += 4 // length of extension header
} else { } else {
dataLen = len(b) dataLen = len(b)
@ -50,8 +52,8 @@ func multipartMessageOrigDatagramLen(proto int, b []byte) int {
// marshalMultipartMessageBody takes data as an original datagram and // marshalMultipartMessageBody takes data as an original datagram and
// exts as extesnsions, and returns a binary encoding of message body. // exts as extesnsions, and returns a binary encoding of message body.
// It can be used for non-multipart message bodies when exts is nil. // It can be used for non-multipart message bodies when exts is nil.
func marshalMultipartMessageBody(proto int, data []byte, exts []Extension) ([]byte, error) { func marshalMultipartMessageBody(proto int, withOrigDgram bool, data []byte, exts []Extension) ([]byte, error) {
bodyLen, dataLen := multipartMessageBodyDataLen(proto, data, exts) bodyLen, dataLen := multipartMessageBodyDataLen(proto, withOrigDgram, data, exts)
b := make([]byte, 4+bodyLen) b := make([]byte, 4+bodyLen)
copy(b[4:], data) copy(b[4:], data)
off := dataLen + 4 off := dataLen + 4
@ -71,16 +73,23 @@ func marshalMultipartMessageBody(proto int, data []byte, exts []Extension) ([]by
return nil, err return nil, err
} }
off += ext.Len(proto) off += ext.Len(proto)
case *InterfaceIdent:
if err := ext.marshal(proto, b[off:]); err != nil {
return nil, err
}
off += ext.Len(proto)
} }
} }
s := checksum(b[dataLen+4:]) s := checksum(b[dataLen+4:])
b[dataLen+4+2] ^= byte(s) b[dataLen+4+2] ^= byte(s)
b[dataLen+4+3] ^= byte(s >> 8) b[dataLen+4+3] ^= byte(s >> 8)
switch proto { if withOrigDgram {
case iana.ProtocolICMP: switch proto {
b[1] = byte(dataLen / 4) case iana.ProtocolICMP:
case iana.ProtocolIPv6ICMP: b[1] = byte(dataLen / 4)
b[0] = byte(dataLen / 8) case iana.ProtocolIPv6ICMP:
b[0] = byte(dataLen / 8)
}
} }
} }
return b, nil return b, nil
@ -88,7 +97,7 @@ func marshalMultipartMessageBody(proto int, data []byte, exts []Extension) ([]by
// parseMultipartMessageBody parses b as either a non-multipart // parseMultipartMessageBody parses b as either a non-multipart
// message body or a multipart message body. // message body or a multipart message body.
func parseMultipartMessageBody(proto int, b []byte) ([]byte, []Extension, error) { func parseMultipartMessageBody(proto int, typ Type, b []byte) ([]byte, []Extension, error) {
var l int var l int
switch proto { switch proto {
case iana.ProtocolICMP: case iana.ProtocolICMP:
@ -99,11 +108,14 @@ func parseMultipartMessageBody(proto int, b []byte) ([]byte, []Extension, error)
if len(b) == 4 { if len(b) == 4 {
return nil, nil, nil return nil, nil, nil
} }
exts, l, err := parseExtensions(b[4:], l) exts, l, err := parseExtensions(typ, b[4:], l)
if err != nil { if err != nil {
l = len(b) - 4 l = len(b) - 4
} }
data := make([]byte, l) var data []byte
copy(data, b[4:]) if l > 0 {
data = make([]byte, l)
copy(data, b[4:])
}
return data, exts, nil return data, exts, nil
} }

View File

@ -5,6 +5,7 @@
package icmp_test package icmp_test
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"reflect" "reflect"
@ -16,425 +17,557 @@ import (
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
) )
var marshalAndParseMultipartMessageForIPv4Tests = []icmp.Message{ func TestMarshalAndParseMultipartMessage(t *testing.T) {
{ fn := func(t *testing.T, proto int, tm icmp.Message) error {
Type: ipv4.ICMPTypeDestinationUnreachable, Code: 15, b, err := tm.Marshal(nil)
Body: &icmp.DstUnreach{
Data: []byte("ERROR-INVOKING-PACKET"),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{
Class: 1,
Type: 1,
Labels: []icmp.MPLSLabel{
{
Label: 16014,
TC: 0x4,
S: true,
TTL: 255,
},
},
},
&icmp.InterfaceInfo{
Class: 2,
Type: 0x0f,
Interface: &net.Interface{
Index: 15,
Name: "en101",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.IPv4(192, 168, 0, 1).To4(),
},
},
},
},
},
{
Type: ipv4.ICMPTypeTimeExceeded, Code: 1,
Body: &icmp.TimeExceeded{
Data: []byte("ERROR-INVOKING-PACKET"),
Extensions: []icmp.Extension{
&icmp.InterfaceInfo{
Class: 2,
Type: 0x0f,
Interface: &net.Interface{
Index: 15,
Name: "en101",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.IPv4(192, 168, 0, 1).To4(),
},
},
&icmp.MPLSLabelStack{
Class: 1,
Type: 1,
Labels: []icmp.MPLSLabel{
{
Label: 16014,
TC: 0x4,
S: true,
TTL: 255,
},
},
},
},
},
},
{
Type: ipv4.ICMPTypeParameterProblem, Code: 2,
Body: &icmp.ParamProb{
Pointer: 8,
Data: []byte("ERROR-INVOKING-PACKET"),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{
Class: 1,
Type: 1,
Labels: []icmp.MPLSLabel{
{
Label: 16014,
TC: 0x4,
S: true,
TTL: 255,
},
},
},
&icmp.InterfaceInfo{
Class: 2,
Type: 0x0f,
Interface: &net.Interface{
Index: 15,
Name: "en101",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.IPv4(192, 168, 0, 1).To4(),
},
},
&icmp.InterfaceInfo{
Class: 2,
Type: 0x2f,
Interface: &net.Interface{
Index: 16,
Name: "en102",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.IPv4(192, 168, 0, 2).To4(),
},
},
},
},
},
}
func TestMarshalAndParseMultipartMessageForIPv4(t *testing.T) {
for i, tt := range marshalAndParseMultipartMessageForIPv4Tests {
b, err := tt.Marshal(nil)
if err != nil { if err != nil {
t.Fatal(err) return err
} }
if b[5] != 32 { switch tm.Type {
t.Errorf("#%v: got %v; want 32", i, b[5]) case ipv4.ICMPTypeExtendedEchoRequest, ipv6.ICMPTypeExtendedEchoRequest:
default:
switch proto {
case iana.ProtocolICMP:
if b[5] != 32 {
return fmt.Errorf("got %d; want 32", b[5])
}
case iana.ProtocolIPv6ICMP:
if b[4] != 16 {
return fmt.Errorf("got %d; want 16", b[4])
}
default:
return fmt.Errorf("unknown protocol: %d", proto)
}
} }
m, err := icmp.ParseMessage(iana.ProtocolICMP, b) m, err := icmp.ParseMessage(proto, b)
if err != nil { if err != nil {
t.Fatal(err) return err
} }
if m.Type != tt.Type || m.Code != tt.Code { if m.Type != tm.Type || m.Code != tm.Code {
t.Errorf("#%v: got %v; want %v", i, m, &tt) return fmt.Errorf("got %v; want %v", m, &tm)
} }
switch m.Type { switch m.Type {
case ipv4.ICMPTypeDestinationUnreachable: case ipv4.ICMPTypeExtendedEchoRequest, ipv6.ICMPTypeExtendedEchoRequest:
got, want := m.Body.(*icmp.DstUnreach), tt.Body.(*icmp.DstUnreach) got, want := m.Body.(*icmp.ExtendedEchoRequest), tm.Body.(*icmp.ExtendedEchoRequest)
if !reflect.DeepEqual(got.Extensions, want.Extensions) { if !reflect.DeepEqual(got.Extensions, want.Extensions) {
t.Error(dumpExtensions(i, got.Extensions, want.Extensions)) return errors.New(dumpExtensions(got.Extensions, want.Extensions))
}
case ipv4.ICMPTypeDestinationUnreachable:
got, want := m.Body.(*icmp.DstUnreach), tm.Body.(*icmp.DstUnreach)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
return errors.New(dumpExtensions(got.Extensions, want.Extensions))
} }
if len(got.Data) != 128 { if len(got.Data) != 128 {
t.Errorf("#%v: got %v; want 128", i, len(got.Data)) return fmt.Errorf("got %d; want 128", len(got.Data))
} }
case ipv4.ICMPTypeTimeExceeded: case ipv4.ICMPTypeTimeExceeded:
got, want := m.Body.(*icmp.TimeExceeded), tt.Body.(*icmp.TimeExceeded) got, want := m.Body.(*icmp.TimeExceeded), tm.Body.(*icmp.TimeExceeded)
if !reflect.DeepEqual(got.Extensions, want.Extensions) { if !reflect.DeepEqual(got.Extensions, want.Extensions) {
t.Error(dumpExtensions(i, got.Extensions, want.Extensions)) return errors.New(dumpExtensions(got.Extensions, want.Extensions))
} }
if len(got.Data) != 128 { if len(got.Data) != 128 {
t.Errorf("#%v: got %v; want 128", i, len(got.Data)) return fmt.Errorf("got %d; want 128", len(got.Data))
} }
case ipv4.ICMPTypeParameterProblem: case ipv4.ICMPTypeParameterProblem:
got, want := m.Body.(*icmp.ParamProb), tt.Body.(*icmp.ParamProb) got, want := m.Body.(*icmp.ParamProb), tm.Body.(*icmp.ParamProb)
if !reflect.DeepEqual(got.Extensions, want.Extensions) { if !reflect.DeepEqual(got.Extensions, want.Extensions) {
t.Error(dumpExtensions(i, got.Extensions, want.Extensions)) return errors.New(dumpExtensions(got.Extensions, want.Extensions))
} }
if len(got.Data) != 128 { if len(got.Data) != 128 {
t.Errorf("#%v: got %v; want 128", i, len(got.Data)) return fmt.Errorf("got %d; want 128", len(got.Data))
} }
case ipv6.ICMPTypeDestinationUnreachable:
got, want := m.Body.(*icmp.DstUnreach), tm.Body.(*icmp.DstUnreach)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
return errors.New(dumpExtensions(got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
return fmt.Errorf("got %d; want 128", len(got.Data))
}
case ipv6.ICMPTypeTimeExceeded:
got, want := m.Body.(*icmp.TimeExceeded), tm.Body.(*icmp.TimeExceeded)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
return errors.New(dumpExtensions(got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
return fmt.Errorf("got %d; want 128", len(got.Data))
}
default:
return fmt.Errorf("unknown message type: %v", m.Type)
} }
return nil
} }
}
var marshalAndParseMultipartMessageForIPv6Tests = []icmp.Message{ t.Run("IPv4", func(t *testing.T) {
{ for i, tm := range []icmp.Message{
Type: ipv6.ICMPTypeDestinationUnreachable, Code: 6, {
Body: &icmp.DstUnreach{ Type: ipv4.ICMPTypeDestinationUnreachable, Code: 15,
Data: []byte("ERROR-INVOKING-PACKET"), Body: &icmp.DstUnreach{
Extensions: []icmp.Extension{ Data: []byte("ERROR-INVOKING-PACKET"),
&icmp.MPLSLabelStack{ Extensions: []icmp.Extension{
Class: 1, &icmp.MPLSLabelStack{
Type: 1, Class: 1,
Labels: []icmp.MPLSLabel{ Type: 1,
{ Labels: []icmp.MPLSLabel{
Label: 16014, {
TC: 0x4, Label: 16014,
S: true, TC: 0x4,
TTL: 255, S: true,
TTL: 255,
},
},
},
&icmp.InterfaceInfo{
Class: 2,
Type: 0x0f,
Interface: &net.Interface{
Index: 15,
Name: "en101",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.IPv4(192, 168, 0, 1).To4(),
},
}, },
}, },
}, },
&icmp.InterfaceInfo{
Class: 2,
Type: 0x0f,
Interface: &net.Interface{
Index: 15,
Name: "en101",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.ParseIP("fe80::1"),
Zone: "en101",
},
},
}, },
}, {
}, Type: ipv4.ICMPTypeTimeExceeded, Code: 1,
{ Body: &icmp.TimeExceeded{
Type: ipv6.ICMPTypeTimeExceeded, Code: 1, Data: []byte("ERROR-INVOKING-PACKET"),
Body: &icmp.TimeExceeded{ Extensions: []icmp.Extension{
Data: []byte("ERROR-INVOKING-PACKET"), &icmp.InterfaceInfo{
Extensions: []icmp.Extension{ Class: 2,
&icmp.InterfaceInfo{ Type: 0x0f,
Class: 2, Interface: &net.Interface{
Type: 0x0f, Index: 15,
Interface: &net.Interface{ Name: "en101",
Index: 15, MTU: 8192,
Name: "en101", },
MTU: 8192, Addr: &net.IPAddr{
}, IP: net.IPv4(192, 168, 0, 1).To4(),
Addr: &net.IPAddr{ },
IP: net.ParseIP("fe80::1"), },
Zone: "en101", &icmp.MPLSLabelStack{
}, Class: 1,
}, Type: 1,
&icmp.MPLSLabelStack{ Labels: []icmp.MPLSLabel{
Class: 1, {
Type: 1, Label: 16014,
Labels: []icmp.MPLSLabel{ TC: 0x4,
{ S: true,
Label: 16014, TTL: 255,
TC: 0x4, },
S: true, },
TTL: 255,
}, },
}, },
}, },
&icmp.InterfaceInfo{ },
Class: 2, {
Type: 0x2f, Type: ipv4.ICMPTypeParameterProblem, Code: 2,
Interface: &net.Interface{ Body: &icmp.ParamProb{
Index: 16, Pointer: 8,
Name: "en102", Data: []byte("ERROR-INVOKING-PACKET"),
MTU: 8192, Extensions: []icmp.Extension{
}, &icmp.MPLSLabelStack{
Addr: &net.IPAddr{ Class: 1,
IP: net.ParseIP("fe80::1"), Type: 1,
Zone: "en102", Labels: []icmp.MPLSLabel{
{
Label: 16014,
TC: 0x4,
S: true,
TTL: 255,
},
},
},
&icmp.InterfaceInfo{
Class: 2,
Type: 0x0f,
Interface: &net.Interface{
Index: 15,
Name: "en101",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.IPv4(192, 168, 0, 1).To4(),
},
},
&icmp.InterfaceInfo{
Class: 2,
Type: 0x2f,
Interface: &net.Interface{
Index: 16,
Name: "en102",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.IPv4(192, 168, 0, 2).To4(),
},
},
}, },
}, },
}, },
}, {
}, Type: ipv4.ICMPTypeExtendedEchoRequest, Code: 0,
} Body: &icmp.ExtendedEchoRequest{
ID: 1, Seq: 2, Local: true,
func TestMarshalAndParseMultipartMessageForIPv6(t *testing.T) { Extensions: []icmp.Extension{
pshicmp := icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1")) &icmp.InterfaceIdent{
for i, tt := range marshalAndParseMultipartMessageForIPv6Tests { Class: 3,
for _, psh := range [][]byte{pshicmp, nil} { Type: 1,
b, err := tt.Marshal(psh) Name: "en101",
if err != nil { },
t.Fatal(err) },
} },
if b[4] != 16 { },
t.Errorf("#%v: got %v; want 16", i, b[4]) {
} Type: ipv4.ICMPTypeExtendedEchoRequest, Code: 0,
m, err := icmp.ParseMessage(iana.ProtocolIPv6ICMP, b) Body: &icmp.ExtendedEchoRequest{
if err != nil { ID: 1, Seq: 2, Local: true,
t.Fatal(err) Extensions: []icmp.Extension{
} &icmp.InterfaceIdent{
if m.Type != tt.Type || m.Code != tt.Code { Class: 3,
t.Errorf("#%v: got %v; want %v", i, m, &tt) Type: 2,
} Index: 911,
switch m.Type { },
case ipv6.ICMPTypeDestinationUnreachable: &icmp.InterfaceIdent{
got, want := m.Body.(*icmp.DstUnreach), tt.Body.(*icmp.DstUnreach) Class: 3,
if !reflect.DeepEqual(got.Extensions, want.Extensions) { Type: 1,
t.Error(dumpExtensions(i, got.Extensions, want.Extensions)) Name: "en101",
} },
if len(got.Data) != 128 { },
t.Errorf("#%v: got %v; want 128", i, len(got.Data)) },
} },
case ipv6.ICMPTypeTimeExceeded: {
got, want := m.Body.(*icmp.TimeExceeded), tt.Body.(*icmp.TimeExceeded) Type: ipv4.ICMPTypeExtendedEchoRequest, Code: 0,
if !reflect.DeepEqual(got.Extensions, want.Extensions) { Body: &icmp.ExtendedEchoRequest{
t.Error(dumpExtensions(i, got.Extensions, want.Extensions)) ID: 1, Seq: 2,
} Extensions: []icmp.Extension{
if len(got.Data) != 128 { &icmp.InterfaceIdent{
t.Errorf("#%v: got %v; want 128", i, len(got.Data)) Class: 3,
} Type: 3,
AFI: iana.AddrFamily48bitMAC,
Addr: []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab},
},
},
},
},
} {
if err := fn(t, iana.ProtocolICMP, tm); err != nil {
t.Errorf("#%d: %v", i, err)
} }
} }
} })
t.Run("IPv6", func(t *testing.T) {
for i, tm := range []icmp.Message{
{
Type: ipv6.ICMPTypeDestinationUnreachable, Code: 6,
Body: &icmp.DstUnreach{
Data: []byte("ERROR-INVOKING-PACKET"),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{
Class: 1,
Type: 1,
Labels: []icmp.MPLSLabel{
{
Label: 16014,
TC: 0x4,
S: true,
TTL: 255,
},
},
},
&icmp.InterfaceInfo{
Class: 2,
Type: 0x0f,
Interface: &net.Interface{
Index: 15,
Name: "en101",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.ParseIP("fe80::1"),
Zone: "en101",
},
},
},
},
},
{
Type: ipv6.ICMPTypeTimeExceeded, Code: 1,
Body: &icmp.TimeExceeded{
Data: []byte("ERROR-INVOKING-PACKET"),
Extensions: []icmp.Extension{
&icmp.InterfaceInfo{
Class: 2,
Type: 0x0f,
Interface: &net.Interface{
Index: 15,
Name: "en101",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.ParseIP("fe80::1"),
Zone: "en101",
},
},
&icmp.MPLSLabelStack{
Class: 1,
Type: 1,
Labels: []icmp.MPLSLabel{
{
Label: 16014,
TC: 0x4,
S: true,
TTL: 255,
},
},
},
&icmp.InterfaceInfo{
Class: 2,
Type: 0x2f,
Interface: &net.Interface{
Index: 16,
Name: "en102",
MTU: 8192,
},
Addr: &net.IPAddr{
IP: net.ParseIP("fe80::1"),
Zone: "en102",
},
},
},
},
},
{
Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 0,
Body: &icmp.ExtendedEchoRequest{
ID: 1, Seq: 2, Local: true,
Extensions: []icmp.Extension{
&icmp.InterfaceIdent{
Class: 3,
Type: 1,
Name: "en101",
},
},
},
},
{
Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 0,
Body: &icmp.ExtendedEchoRequest{
ID: 1, Seq: 2, Local: true,
Extensions: []icmp.Extension{
&icmp.InterfaceIdent{
Class: 3,
Type: 1,
Name: "en101",
},
&icmp.InterfaceIdent{
Class: 3,
Type: 2,
Index: 911,
},
},
},
},
{
Type: ipv6.ICMPTypeExtendedEchoRequest, Code: 0,
Body: &icmp.ExtendedEchoRequest{
ID: 1, Seq: 2,
Extensions: []icmp.Extension{
&icmp.InterfaceIdent{
Class: 3,
Type: 3,
AFI: iana.AddrFamilyIPv4,
Addr: []byte{192, 0, 2, 1},
},
},
},
},
} {
if err := fn(t, iana.ProtocolIPv6ICMP, tm); err != nil {
t.Errorf("#%d: %v", i, err)
}
}
})
} }
func dumpExtensions(i int, gotExts, wantExts []icmp.Extension) string { func dumpExtensions(gotExts, wantExts []icmp.Extension) string {
var s string var s string
for j, got := range gotExts { for i, got := range gotExts {
switch got := got.(type) { switch got := got.(type) {
case *icmp.MPLSLabelStack: case *icmp.MPLSLabelStack:
want := wantExts[j].(*icmp.MPLSLabelStack) want := wantExts[i].(*icmp.MPLSLabelStack)
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
s += fmt.Sprintf("#%v/%v: got %#v; want %#v\n", i, j, got, want) s += fmt.Sprintf("#%d: got %#v; want %#v\n", i, got, want)
} }
case *icmp.InterfaceInfo: case *icmp.InterfaceInfo:
want := wantExts[j].(*icmp.InterfaceInfo) want := wantExts[i].(*icmp.InterfaceInfo)
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
s += fmt.Sprintf("#%v/%v: got %#v, %#v, %#v; want %#v, %#v, %#v\n", i, j, got, got.Interface, got.Addr, want, want.Interface, want.Addr) s += fmt.Sprintf("#%d: got %#v, %#v, %#v; want %#v, %#v, %#v\n", i, got, got.Interface, got.Addr, want, want.Interface, want.Addr)
}
case *icmp.InterfaceIdent:
want := wantExts[i].(*icmp.InterfaceIdent)
if !reflect.DeepEqual(got, want) {
s += fmt.Sprintf("#%d: got %#v; want %#v\n", i, got, want)
} }
} }
} }
if len(s) == 0 {
return "<nil>"
}
return s[:len(s)-1] return s[:len(s)-1]
} }
var multipartMessageBodyLenTests = []struct {
proto int
in icmp.MessageBody
out int
}{
{
iana.ProtocolICMP,
&icmp.DstUnreach{
Data: make([]byte, ipv4.HeaderLen),
},
4 + ipv4.HeaderLen, // unused and original datagram
},
{
iana.ProtocolICMP,
&icmp.TimeExceeded{
Data: make([]byte, ipv4.HeaderLen),
},
4 + ipv4.HeaderLen, // unused and original datagram
},
{
iana.ProtocolICMP,
&icmp.ParamProb{
Data: make([]byte, ipv4.HeaderLen),
},
4 + ipv4.HeaderLen, // [pointer, unused] and original datagram
},
{
iana.ProtocolICMP,
&icmp.ParamProb{
Data: make([]byte, ipv4.HeaderLen),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 128, // [pointer, length, unused], extension header, object header, object payload, original datagram
},
{
iana.ProtocolICMP,
&icmp.ParamProb{
Data: make([]byte, 128),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 128, // [pointer, length, unused], extension header, object header, object payload and original datagram
},
{
iana.ProtocolICMP,
&icmp.ParamProb{
Data: make([]byte, 129),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 132, // [pointer, length, unused], extension header, object header, object payload and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.DstUnreach{
Data: make([]byte, ipv6.HeaderLen),
},
4 + ipv6.HeaderLen, // unused and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.PacketTooBig{
Data: make([]byte, ipv6.HeaderLen),
},
4 + ipv6.HeaderLen, // mtu and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.TimeExceeded{
Data: make([]byte, ipv6.HeaderLen),
},
4 + ipv6.HeaderLen, // unused and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.ParamProb{
Data: make([]byte, ipv6.HeaderLen),
},
4 + ipv6.HeaderLen, // pointer and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.DstUnreach{
Data: make([]byte, 127),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 128, // [length, unused], extension header, object header, object payload and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.DstUnreach{
Data: make([]byte, 128),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 128, // [length, unused], extension header, object header, object payload and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.DstUnreach{
Data: make([]byte, 129),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 136, // [length, unused], extension header, object header, object payload and original datagram
},
}
func TestMultipartMessageBodyLen(t *testing.T) { func TestMultipartMessageBodyLen(t *testing.T) {
for i, tt := range multipartMessageBodyLenTests { for i, tt := range []struct {
proto int
in icmp.MessageBody
out int
}{
{
iana.ProtocolICMP,
&icmp.DstUnreach{
Data: make([]byte, ipv4.HeaderLen),
},
4 + ipv4.HeaderLen, // unused and original datagram
},
{
iana.ProtocolICMP,
&icmp.TimeExceeded{
Data: make([]byte, ipv4.HeaderLen),
},
4 + ipv4.HeaderLen, // unused and original datagram
},
{
iana.ProtocolICMP,
&icmp.ParamProb{
Data: make([]byte, ipv4.HeaderLen),
},
4 + ipv4.HeaderLen, // [pointer, unused] and original datagram
},
{
iana.ProtocolICMP,
&icmp.ParamProb{
Data: make([]byte, ipv4.HeaderLen),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 128, // [pointer, length, unused], extension header, object header, object payload, original datagram
},
{
iana.ProtocolICMP,
&icmp.ParamProb{
Data: make([]byte, 128),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 128, // [pointer, length, unused], extension header, object header, object payload and original datagram
},
{
iana.ProtocolICMP,
&icmp.ParamProb{
Data: make([]byte, 129),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 132, // [pointer, length, unused], extension header, object header, object payload and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.DstUnreach{
Data: make([]byte, ipv6.HeaderLen),
},
4 + ipv6.HeaderLen, // unused and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.PacketTooBig{
Data: make([]byte, ipv6.HeaderLen),
},
4 + ipv6.HeaderLen, // mtu and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.TimeExceeded{
Data: make([]byte, ipv6.HeaderLen),
},
4 + ipv6.HeaderLen, // unused and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.ParamProb{
Data: make([]byte, ipv6.HeaderLen),
},
4 + ipv6.HeaderLen, // pointer and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.DstUnreach{
Data: make([]byte, 127),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 128, // [length, unused], extension header, object header, object payload and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.DstUnreach{
Data: make([]byte, 128),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 128, // [length, unused], extension header, object header, object payload and original datagram
},
{
iana.ProtocolIPv6ICMP,
&icmp.DstUnreach{
Data: make([]byte, 129),
Extensions: []icmp.Extension{
&icmp.MPLSLabelStack{},
},
},
4 + 4 + 4 + 0 + 136, // [length, unused], extension header, object header, object payload and original datagram
},
{
iana.ProtocolICMP,
&icmp.ExtendedEchoRequest{},
4, // [id, seq, l-bit]
},
{
iana.ProtocolICMP,
&icmp.ExtendedEchoRequest{
Extensions: []icmp.Extension{
&icmp.InterfaceIdent{},
},
},
4 + 4 + 4, // [id, seq, l-bit], extension header, object header
},
{
iana.ProtocolIPv6ICMP,
&icmp.ExtendedEchoRequest{
Extensions: []icmp.Extension{
&icmp.InterfaceIdent{
Type: 3,
AFI: iana.AddrFamilyNSAP,
Addr: []byte{0x49, 0x00, 0x01, 0xaa, 0xaa, 0xbb, 0xbb, 0xcc, 0xcc, 0x00},
},
},
},
4 + 4 + 4 + 16, // [id, seq, l-bit], extension header, object header, object payload
},
} {
if out := tt.in.Len(tt.proto); out != tt.out { if out := tt.in.Len(tt.proto); out != tt.out {
t.Errorf("#%d: got %d; want %d", i, out, tt.out) t.Errorf("#%d: got %d; want %d", i, out, tt.out)
} }

View File

@ -29,7 +29,7 @@ func (p *PacketTooBig) Marshal(proto int) ([]byte, error) {
} }
// parsePacketTooBig parses b as an ICMP packet too big message body. // parsePacketTooBig parses b as an ICMP packet too big message body.
func parsePacketTooBig(proto int, b []byte) (MessageBody, error) { func parsePacketTooBig(proto int, _ Type, b []byte) (MessageBody, error) {
bodyLen := len(b) bodyLen := len(b)
if bodyLen < 4 { if bodyLen < 4 {
return nil, errMessageTooShort return nil, errMessageTooShort

View File

@ -21,7 +21,7 @@ func (p *ParamProb) Len(proto int) int {
if p == nil { if p == nil {
return 0 return 0
} }
l, _ := multipartMessageBodyDataLen(proto, p.Data, p.Extensions) l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
return 4 + l return 4 + l
} }
@ -33,7 +33,7 @@ func (p *ParamProb) Marshal(proto int) ([]byte, error) {
copy(b[4:], p.Data) copy(b[4:], p.Data)
return b, nil return b, nil
} }
b, err := marshalMultipartMessageBody(proto, p.Data, p.Extensions) b, err := marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -42,7 +42,7 @@ func (p *ParamProb) Marshal(proto int) ([]byte, error) {
} }
// parseParamProb parses b as an ICMP parameter problem message body. // parseParamProb parses b as an ICMP parameter problem message body.
func parseParamProb(proto int, b []byte) (MessageBody, error) { func parseParamProb(proto int, typ Type, b []byte) (MessageBody, error) {
if len(b) < 4 { if len(b) < 4 {
return nil, errMessageTooShort return nil, errMessageTooShort
} }
@ -55,7 +55,7 @@ func parseParamProb(proto int, b []byte) (MessageBody, error) {
} }
p.Pointer = uintptr(b[0]) p.Pointer = uintptr(b[0])
var err error var err error
p.Data, p.Extensions, err = parseMultipartMessageBody(proto, b) p.Data, p.Extensions, err = parseMultipartMessageBody(proto, typ, b)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,200 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package icmp_test
import (
"errors"
"fmt"
"net"
"os"
"runtime"
"sync"
"testing"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/internal/iana"
"golang.org/x/net/internal/nettest"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
func googleAddr(c *icmp.PacketConn, protocol int) (net.Addr, error) {
const host = "www.google.com"
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}
netaddr := func(ip net.IP) (net.Addr, error) {
switch c.LocalAddr().(type) {
case *net.UDPAddr:
return &net.UDPAddr{IP: ip}, nil
case *net.IPAddr:
return &net.IPAddr{IP: ip}, nil
default:
return nil, errors.New("neither UDPAddr nor IPAddr")
}
}
for _, ip := range ips {
switch protocol {
case iana.ProtocolICMP:
if ip.To4() != nil {
return netaddr(ip)
}
case iana.ProtocolIPv6ICMP:
if ip.To16() != nil && ip.To4() == nil {
return netaddr(ip)
}
}
}
return nil, errors.New("no A or AAAA record")
}
type pingTest struct {
network, address string
protocol int
mtype icmp.Type
}
var nonPrivilegedPingTests = []pingTest{
{"udp4", "0.0.0.0", iana.ProtocolICMP, ipv4.ICMPTypeEcho},
{"udp6", "::", iana.ProtocolIPv6ICMP, ipv6.ICMPTypeEchoRequest},
}
func TestNonPrivilegedPing(t *testing.T) {
if testing.Short() {
t.Skip("avoid external network")
}
switch runtime.GOOS {
case "darwin":
case "linux":
t.Log("you may need to adjust the net.ipv4.ping_group_range kernel state")
default:
t.Skipf("not supported on %s", runtime.GOOS)
}
for i, tt := range nonPrivilegedPingTests {
if err := doPing(tt, i); err != nil {
t.Error(err)
}
}
}
var privilegedPingTests = []pingTest{
{"ip4:icmp", "0.0.0.0", iana.ProtocolICMP, ipv4.ICMPTypeEcho},
{"ip6:ipv6-icmp", "::", iana.ProtocolIPv6ICMP, ipv6.ICMPTypeEchoRequest},
}
func TestPrivilegedPing(t *testing.T) {
if testing.Short() {
t.Skip("avoid external network")
}
if m, ok := nettest.SupportsRawIPSocket(); !ok {
t.Skip(m)
}
for i, tt := range privilegedPingTests {
if err := doPing(tt, i); err != nil {
t.Error(err)
}
}
}
func doPing(tt pingTest, seq int) error {
c, err := icmp.ListenPacket(tt.network, tt.address)
if err != nil {
return err
}
defer c.Close()
dst, err := googleAddr(c, tt.protocol)
if err != nil {
return err
}
if tt.network != "udp6" && tt.protocol == iana.ProtocolIPv6ICMP {
var f ipv6.ICMPFilter
f.SetAll(true)
f.Accept(ipv6.ICMPTypeDestinationUnreachable)
f.Accept(ipv6.ICMPTypePacketTooBig)
f.Accept(ipv6.ICMPTypeTimeExceeded)
f.Accept(ipv6.ICMPTypeParameterProblem)
f.Accept(ipv6.ICMPTypeEchoReply)
if err := c.IPv6PacketConn().SetICMPFilter(&f); err != nil {
return err
}
}
wm := icmp.Message{
Type: tt.mtype, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff, Seq: 1 << uint(seq),
Data: []byte("HELLO-R-U-THERE"),
},
}
wb, err := wm.Marshal(nil)
if err != nil {
return err
}
if n, err := c.WriteTo(wb, dst); err != nil {
return err
} else if n != len(wb) {
return fmt.Errorf("got %v; want %v", n, len(wb))
}
rb := make([]byte, 1500)
if err := c.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
return err
}
n, peer, err := c.ReadFrom(rb)
if err != nil {
return err
}
rm, err := icmp.ParseMessage(tt.protocol, rb[:n])
if err != nil {
return err
}
switch rm.Type {
case ipv4.ICMPTypeEchoReply, ipv6.ICMPTypeEchoReply:
return nil
default:
return fmt.Errorf("got %+v from %v; want echo reply", rm, peer)
}
}
func TestConcurrentNonPrivilegedListenPacket(t *testing.T) {
if testing.Short() {
t.Skip("avoid external network")
}
switch runtime.GOOS {
case "darwin":
case "linux":
t.Log("you may need to adjust the net.ipv4.ping_group_range kernel state")
default:
t.Skipf("not supported on %s", runtime.GOOS)
}
network, address := "udp4", "127.0.0.1"
if !nettest.SupportsIPv4() {
network, address = "udp6", "::1"
}
const N = 1000
var wg sync.WaitGroup
wg.Add(N)
for i := 0; i < N; i++ {
go func() {
defer wg.Done()
c, err := icmp.ListenPacket(network, address)
if err != nil {
t.Error(err)
return
}
c.Close()
}()
}
wg.Wait()
}

View File

@ -15,23 +15,23 @@ func (p *TimeExceeded) Len(proto int) int {
if p == nil { if p == nil {
return 0 return 0
} }
l, _ := multipartMessageBodyDataLen(proto, p.Data, p.Extensions) l, _ := multipartMessageBodyDataLen(proto, true, p.Data, p.Extensions)
return 4 + l return 4 + l
} }
// Marshal implements the Marshal method of MessageBody interface. // Marshal implements the Marshal method of MessageBody interface.
func (p *TimeExceeded) Marshal(proto int) ([]byte, error) { func (p *TimeExceeded) Marshal(proto int) ([]byte, error) {
return marshalMultipartMessageBody(proto, p.Data, p.Extensions) return marshalMultipartMessageBody(proto, true, p.Data, p.Extensions)
} }
// parseTimeExceeded parses b as an ICMP time exceeded message body. // parseTimeExceeded parses b as an ICMP time exceeded message body.
func parseTimeExceeded(proto int, b []byte) (MessageBody, error) { func parseTimeExceeded(proto int, typ Type, b []byte) (MessageBody, error) {
if len(b) < 4 { if len(b) < 4 {
return nil, errMessageTooShort return nil, errMessageTooShort
} }
p := &TimeExceeded{} p := &TimeExceeded{}
var err error var err error
p.Data, p.Extensions, err = parseMultipartMessageBody(proto, b) p.Data, p.Extensions, err = parseMultipartMessageBody(proto, typ, b)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -51,6 +51,10 @@ func ExampleNew() {
idna.Transitional(true)) // Map ß -> ss idna.Transitional(true)) // Map ß -> ss
fmt.Println(p.ToASCII("*.faß.com")) fmt.Println(p.ToASCII("*.faß.com"))
// Lookup for registration. Also does not allow '*'.
p = idna.New(idna.ValidateForRegistration())
fmt.Println(p.ToUnicode("*.faß.com"))
// Set up a profile maps for lookup, but allows wild cards. // Set up a profile maps for lookup, but allows wild cards.
p = idna.New( p = idna.New(
idna.MapForLookup(), idna.MapForLookup(),
@ -60,6 +64,7 @@ func ExampleNew() {
// Output: // Output:
// *.xn--fa-hia.com <nil> // *.xn--fa-hia.com <nil>
// *.fass.com idna: disallowed rune U+002E // *.fass.com idna: disallowed rune U+002A
// *.faß.com idna: disallowed rune U+002A
// *.fass.com <nil> // *.fass.com <nil>
} }

126
vendor/golang.org/x/net/idna/idna.go generated vendored
View File

@ -21,6 +21,7 @@ import (
"unicode/utf8" "unicode/utf8"
"golang.org/x/text/secure/bidirule" "golang.org/x/text/secure/bidirule"
"golang.org/x/text/unicode/bidi"
"golang.org/x/text/unicode/norm" "golang.org/x/text/unicode/norm"
) )
@ -67,6 +68,15 @@ func VerifyDNSLength(verify bool) Option {
return func(o *options) { o.verifyDNSLength = verify } return func(o *options) { o.verifyDNSLength = verify }
} }
// RemoveLeadingDots removes leading label separators. Leading runes that map to
// dots, such as U+3002 IDEOGRAPHIC FULL STOP, are removed as well.
//
// This is the behavior suggested by the UTS #46 and is adopted by some
// browsers.
func RemoveLeadingDots(remove bool) Option {
return func(o *options) { o.removeLeadingDots = remove }
}
// ValidateLabels sets whether to check the mandatory label validation criteria // ValidateLabels sets whether to check the mandatory label validation criteria
// as defined in Section 5.4 of RFC 5891. This includes testing for correct use // as defined in Section 5.4 of RFC 5891. This includes testing for correct use
// of hyphens ('-'), normalization, validity of runes, and the context rules. // of hyphens ('-'), normalization, validity of runes, and the context rules.
@ -83,7 +93,7 @@ func ValidateLabels(enable bool) Option {
} }
} }
// StrictDomainName limits the set of permissable ASCII characters to those // StrictDomainName limits the set of permissible ASCII characters to those
// allowed in domain names as defined in RFC 1034 (A-Z, a-z, 0-9 and the // allowed in domain names as defined in RFC 1034 (A-Z, a-z, 0-9 and the
// hyphen). This is set by default for MapForLookup and ValidateForRegistration. // hyphen). This is set by default for MapForLookup and ValidateForRegistration.
// //
@ -137,10 +147,11 @@ func MapForLookup() Option {
} }
type options struct { type options struct {
transitional bool transitional bool
useSTD3Rules bool useSTD3Rules bool
validateLabels bool validateLabels bool
verifyDNSLength bool verifyDNSLength bool
removeLeadingDots bool
trie *idnaTrie trie *idnaTrie
@ -149,14 +160,14 @@ type options struct {
// mapping implements a validation and mapping step as defined in RFC 5895 // mapping implements a validation and mapping step as defined in RFC 5895
// or UTS 46, tailored to, for example, domain registration or lookup. // or UTS 46, tailored to, for example, domain registration or lookup.
mapping func(p *Profile, s string) (string, error) mapping func(p *Profile, s string) (mapped string, isBidi bool, err error)
// bidirule, if specified, checks whether s conforms to the Bidi Rule // bidirule, if specified, checks whether s conforms to the Bidi Rule
// defined in RFC 5893. // defined in RFC 5893.
bidirule func(s string) bool bidirule func(s string) bool
} }
// A Profile defines the configuration of a IDNA mapper. // A Profile defines the configuration of an IDNA mapper.
type Profile struct { type Profile struct {
options options
} }
@ -289,12 +300,16 @@ func (e runeError) Error() string {
// see http://www.unicode.org/reports/tr46. // see http://www.unicode.org/reports/tr46.
func (p *Profile) process(s string, toASCII bool) (string, error) { func (p *Profile) process(s string, toASCII bool) (string, error) {
var err error var err error
var isBidi bool
if p.mapping != nil { if p.mapping != nil {
s, err = p.mapping(p, s) s, isBidi, err = p.mapping(p, s)
} }
// Remove leading empty labels. // Remove leading empty labels.
for ; len(s) > 0 && s[0] == '.'; s = s[1:] { if p.removeLeadingDots {
for ; len(s) > 0 && s[0] == '.'; s = s[1:] {
}
} }
// TODO: allow for a quick check of the tables data.
// It seems like we should only create this error on ToASCII, but the // It seems like we should only create this error on ToASCII, but the
// UTS 46 conformance tests suggests we should always check this. // UTS 46 conformance tests suggests we should always check this.
if err == nil && p.verifyDNSLength && s == "" { if err == nil && p.verifyDNSLength && s == "" {
@ -320,6 +335,7 @@ func (p *Profile) process(s string, toASCII bool) (string, error) {
// Spec says keep the old label. // Spec says keep the old label.
continue continue
} }
isBidi = isBidi || bidirule.DirectionString(u) != bidi.LeftToRight
labels.set(u) labels.set(u)
if err == nil && p.validateLabels { if err == nil && p.validateLabels {
err = p.fromPuny(p, u) err = p.fromPuny(p, u)
@ -334,6 +350,14 @@ func (p *Profile) process(s string, toASCII bool) (string, error) {
err = p.validateLabel(label) err = p.validateLabel(label)
} }
} }
if isBidi && p.bidirule != nil && err == nil {
for labels.reset(); !labels.done(); labels.next() {
if !p.bidirule(labels.label()) {
err = &labelError{s, "B"}
break
}
}
}
if toASCII { if toASCII {
for labels.reset(); !labels.done(); labels.next() { for labels.reset(); !labels.done(); labels.next() {
label := labels.label() label := labels.label()
@ -365,41 +389,77 @@ func (p *Profile) process(s string, toASCII bool) (string, error) {
return s, err return s, err
} }
func normalize(p *Profile, s string) (string, error) { func normalize(p *Profile, s string) (mapped string, isBidi bool, err error) {
return norm.NFC.String(s), nil // TODO: consider first doing a quick check to see if any of these checks
// need to be done. This will make it slower in the general case, but
// faster in the common case.
mapped = norm.NFC.String(s)
isBidi = bidirule.DirectionString(mapped) == bidi.RightToLeft
return mapped, isBidi, nil
} }
func validateRegistration(p *Profile, s string) (string, error) { func validateRegistration(p *Profile, s string) (idem string, bidi bool, err error) {
// TODO: filter need for normalization in loop below.
if !norm.NFC.IsNormalString(s) { if !norm.NFC.IsNormalString(s) {
return s, &labelError{s, "V1"} return s, false, &labelError{s, "V1"}
} }
var err error
for i := 0; i < len(s); { for i := 0; i < len(s); {
v, sz := trie.lookupString(s[i:]) v, sz := trie.lookupString(s[i:])
i += sz if sz == 0 {
return s, bidi, runeError(utf8.RuneError)
}
bidi = bidi || info(v).isBidi(s[i:])
// Copy bytes not copied so far. // Copy bytes not copied so far.
switch p.simplify(info(v).category()) { switch p.simplify(info(v).category()) {
// TODO: handle the NV8 defined in the Unicode idna data set to allow // TODO: handle the NV8 defined in the Unicode idna data set to allow
// for strict conformance to IDNA2008. // for strict conformance to IDNA2008.
case valid, deviation: case valid, deviation:
case disallowed, mapped, unknown, ignored: case disallowed, mapped, unknown, ignored:
if err == nil { r, _ := utf8.DecodeRuneInString(s[i:])
r, _ := utf8.DecodeRuneInString(s[i:]) return s, bidi, runeError(r)
err = runeError(r)
}
} }
i += sz
} }
return s, err return s, bidi, nil
} }
func validateAndMap(p *Profile, s string) (string, error) { func (c info) isBidi(s string) bool {
if !c.isMapped() {
return c&attributesMask == rtl
}
// TODO: also store bidi info for mapped data. This is possible, but a bit
// cumbersome and not for the common case.
p, _ := bidi.LookupString(s)
switch p.Class() {
case bidi.R, bidi.AL, bidi.AN:
return true
}
return false
}
func validateAndMap(p *Profile, s string) (vm string, bidi bool, err error) {
var ( var (
err error b []byte
b []byte k int
k int
) )
// combinedInfoBits contains the or-ed bits of all runes. We use this
// to derive the mayNeedNorm bit later. This may trigger normalization
// overeagerly, but it will not do so in the common case. The end result
// is another 10% saving on BenchmarkProfile for the common case.
var combinedInfoBits info
for i := 0; i < len(s); { for i := 0; i < len(s); {
v, sz := trie.lookupString(s[i:]) v, sz := trie.lookupString(s[i:])
if sz == 0 {
b = append(b, s[k:i]...)
b = append(b, "\ufffd"...)
k = len(s)
if err == nil {
err = runeError(utf8.RuneError)
}
break
}
combinedInfoBits |= info(v)
bidi = bidi || info(v).isBidi(s[i:])
start := i start := i
i += sz i += sz
// Copy bytes not copied so far. // Copy bytes not copied so far.
@ -408,7 +468,7 @@ func validateAndMap(p *Profile, s string) (string, error) {
continue continue
case disallowed: case disallowed:
if err == nil { if err == nil {
r, _ := utf8.DecodeRuneInString(s[i:]) r, _ := utf8.DecodeRuneInString(s[start:])
err = runeError(r) err = runeError(r)
} }
continue continue
@ -426,7 +486,9 @@ func validateAndMap(p *Profile, s string) (string, error) {
} }
if k == 0 { if k == 0 {
// No changes so far. // No changes so far.
s = norm.NFC.String(s) if combinedInfoBits&mayNeedNorm != 0 {
s = norm.NFC.String(s)
}
} else { } else {
b = append(b, s[k:]...) b = append(b, s[k:]...)
if norm.NFC.QuickSpan(b) != len(b) { if norm.NFC.QuickSpan(b) != len(b) {
@ -435,7 +497,7 @@ func validateAndMap(p *Profile, s string) (string, error) {
// TODO: the punycode converters require strings as input. // TODO: the punycode converters require strings as input.
s = string(b) s = string(b)
} }
return s, err return s, bidi, err
} }
// A labelIter allows iterating over domain name labels. // A labelIter allows iterating over domain name labels.
@ -530,8 +592,13 @@ func validateFromPunycode(p *Profile, s string) error {
if !norm.NFC.IsNormalString(s) { if !norm.NFC.IsNormalString(s) {
return &labelError{s, "V1"} return &labelError{s, "V1"}
} }
// TODO: detect whether string may have to be normalized in the following
// loop.
for i := 0; i < len(s); { for i := 0; i < len(s); {
v, sz := trie.lookupString(s[i:]) v, sz := trie.lookupString(s[i:])
if sz == 0 {
return runeError(utf8.RuneError)
}
if c := p.simplify(info(v).category()); c != valid && c != deviation { if c := p.simplify(info(v).category()); c != valid && c != deviation {
return &labelError{s, "V6"} return &labelError{s, "V6"}
} }
@ -604,16 +671,13 @@ var joinStates = [][numJoinTypes]joinState{
// validateLabel validates the criteria from Section 4.1. Item 1, 4, and 6 are // validateLabel validates the criteria from Section 4.1. Item 1, 4, and 6 are
// already implicitly satisfied by the overall implementation. // already implicitly satisfied by the overall implementation.
func (p *Profile) validateLabel(s string) error { func (p *Profile) validateLabel(s string) (err error) {
if s == "" { if s == "" {
if p.verifyDNSLength { if p.verifyDNSLength {
return &labelError{s, "A4"} return &labelError{s, "A4"}
} }
return nil return nil
} }
if p.bidirule != nil && !p.bidirule(s) {
return &labelError{s, "B"}
}
if !p.validateLabels { if !p.validateLabels {
return nil return nil
} }

View File

@ -39,5 +39,70 @@ func TestIDNA(t *testing.T) {
} }
} }
func TestIDNASeparators(t *testing.T) {
type subCase struct {
unicode string
wantASCII string
wantErr bool
}
testCases := []struct {
name string
profile *Profile
subCases []subCase
}{
{
name: "Punycode", profile: Punycode,
subCases: []subCase{
{"example\u3002jp", "xn--examplejp-ck3h", false},
{"東京\uFF0Ejp", "xn--jp-l92cn98g071o", false},
{"大阪\uFF61jp", "xn--jp-ku9cz72u463f", false},
},
},
{
name: "Lookup", profile: Lookup,
subCases: []subCase{
{"example\u3002jp", "example.jp", false},
{"東京\uFF0Ejp", "xn--1lqs71d.jp", false},
{"大阪\uFF61jp", "xn--pssu33l.jp", false},
},
},
{
name: "Display", profile: Display,
subCases: []subCase{
{"example\u3002jp", "example.jp", false},
{"東京\uFF0Ejp", "xn--1lqs71d.jp", false},
{"大阪\uFF61jp", "xn--pssu33l.jp", false},
},
},
{
name: "Registration", profile: Registration,
subCases: []subCase{
{"example\u3002jp", "", true},
{"東京\uFF0Ejp", "", true},
{"大阪\uFF61jp", "", true},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, c := range tc.subCases {
gotA, err := tc.profile.ToASCII(c.unicode)
if c.wantErr {
if err == nil {
t.Errorf("ToASCII(%q): got no error, but an error expected", c.unicode)
}
} else {
if err != nil {
t.Errorf("ToASCII(%q): got err=%v, but no error expected", c.unicode, err)
} else if gotA != c.wantASCII {
t.Errorf("ToASCII(%q): got %q, want %q", c.unicode, gotA, c.wantASCII)
}
}
}
})
}
}
// TODO(nigeltao): test errors, once we've specified when ToASCII and ToUnicode // TODO(nigeltao): test errors, once we've specified when ToASCII and ToUnicode
// return errors. // return errors.

4396
vendor/golang.org/x/net/idna/tables.go generated vendored

File diff suppressed because it is too large Load Diff

View File

@ -26,9 +26,9 @@ package idna
// 15..3 index into xor or mapping table // 15..3 index into xor or mapping table
// } // }
// } else { // } else {
// 15..13 unused // 15..14 unused
// 12 modifier (including virama) // 13 mayNeedNorm
// 11 virama modifier // 12..11 attributes
// 10..8 joining type // 10..8 joining type
// 7..3 category type // 7..3 category type
// } // }
@ -49,15 +49,20 @@ const (
joinShift = 8 joinShift = 8
joinMask = 0x07 joinMask = 0x07
viramaModifier = 0x0800 // Attributes
attributesMask = 0x1800
viramaModifier = 0x1800
modifier = 0x1000 modifier = 0x1000
rtl = 0x0800
mayNeedNorm = 0x2000
) )
// A category corresponds to a category defined in the IDNA mapping table. // A category corresponds to a category defined in the IDNA mapping table.
type category uint16 type category uint16
const ( const (
unknown category = 0 // not defined currently in unicode. unknown category = 0 // not currently defined in unicode.
mapped category = 1 mapped category = 1
disallowedSTD3Mapped category = 2 disallowedSTD3Mapped category = 2
deviation category = 3 deviation category = 3
@ -110,5 +115,5 @@ func (c info) isModifier() bool {
} }
func (c info) isViramaModifier() bool { func (c info) isViramaModifier() bool {
return c&(viramaModifier|catSmallMask) == viramaModifier return c&(attributesMask|catSmallMask) == viramaModifier
} }

View File

@ -1,5 +1,5 @@
// go generate gen.go // go generate gen.go
// GENERATED BY THE COMMAND ABOVE; DO NOT EDIT // Code generated by the command above; DO NOT EDIT.
// Package iana provides protocol number resources managed by the Internet Assigned Numbers Authority (IANA). // Package iana provides protocol number resources managed by the Internet Assigned Numbers Authority (IANA).
package iana // import "golang.org/x/net/internal/iana" package iana // import "golang.org/x/net/internal/iana"
@ -38,7 +38,7 @@ const (
CongestionExperienced = 0x3 // CE (Congestion Experienced) CongestionExperienced = 0x3 // CE (Congestion Experienced)
) )
// Protocol Numbers, Updated: 2016-06-22 // Protocol Numbers, Updated: 2017-10-13
const ( const (
ProtocolIP = 0 // IPv4 encapsulation, pseudo protocol number ProtocolIP = 0 // IPv4 encapsulation, pseudo protocol number
ProtocolHOPOPT = 0 // IPv6 Hop-by-Hop Option ProtocolHOPOPT = 0 // IPv6 Hop-by-Hop Option
@ -178,3 +178,50 @@ const (
ProtocolROHC = 142 // Robust Header Compression ProtocolROHC = 142 // Robust Header Compression
ProtocolReserved = 255 // Reserved ProtocolReserved = 255 // Reserved
) )
// Address Family Numbers, Updated: 2016-10-25
const (
AddrFamilyIPv4 = 1 // IP (IP version 4)
AddrFamilyIPv6 = 2 // IP6 (IP version 6)
AddrFamilyNSAP = 3 // NSAP
AddrFamilyHDLC = 4 // HDLC (8-bit multidrop)
AddrFamilyBBN1822 = 5 // BBN 1822
AddrFamily802 = 6 // 802 (includes all 802 media plus Ethernet "canonical format")
AddrFamilyE163 = 7 // E.163
AddrFamilyE164 = 8 // E.164 (SMDS, Frame Relay, ATM)
AddrFamilyF69 = 9 // F.69 (Telex)
AddrFamilyX121 = 10 // X.121 (X.25, Frame Relay)
AddrFamilyIPX = 11 // IPX
AddrFamilyAppletalk = 12 // Appletalk
AddrFamilyDecnetIV = 13 // Decnet IV
AddrFamilyBanyanVines = 14 // Banyan Vines
AddrFamilyE164withSubaddress = 15 // E.164 with NSAP format subaddress
AddrFamilyDNS = 16 // DNS (Domain Name System)
AddrFamilyDistinguishedName = 17 // Distinguished Name
AddrFamilyASNumber = 18 // AS Number
AddrFamilyXTPoverIPv4 = 19 // XTP over IP version 4
AddrFamilyXTPoverIPv6 = 20 // XTP over IP version 6
AddrFamilyXTPnativemodeXTP = 21 // XTP native mode XTP
AddrFamilyFibreChannelWorldWidePortName = 22 // Fibre Channel World-Wide Port Name
AddrFamilyFibreChannelWorldWideNodeName = 23 // Fibre Channel World-Wide Node Name
AddrFamilyGWID = 24 // GWID
AddrFamilyL2VPN = 25 // AFI for L2VPN information
AddrFamilyMPLSTPSectionEndpointID = 26 // MPLS-TP Section Endpoint Identifier
AddrFamilyMPLSTPLSPEndpointID = 27 // MPLS-TP LSP Endpoint Identifier
AddrFamilyMPLSTPPseudowireEndpointID = 28 // MPLS-TP Pseudowire Endpoint Identifier
AddrFamilyMTIPv4 = 29 // MT IP: Multi-Topology IP version 4
AddrFamilyMTIPv6 = 30 // MT IPv6: Multi-Topology IP version 6
AddrFamilyEIGRPCommonServiceFamily = 16384 // EIGRP Common Service Family
AddrFamilyEIGRPIPv4ServiceFamily = 16385 // EIGRP IPv4 Service Family
AddrFamilyEIGRPIPv6ServiceFamily = 16386 // EIGRP IPv6 Service Family
AddrFamilyLISPCanonicalAddressFormat = 16387 // LISP Canonical Address Format (LCAF)
AddrFamilyBGPLS = 16388 // BGP-LS
AddrFamily48bitMAC = 16389 // 48-bit MAC
AddrFamily64bitMAC = 16390 // 64-bit MAC
AddrFamilyOUI = 16391 // OUI
AddrFamilyMACFinal24bits = 16392 // MAC/24
AddrFamilyMACFinal40bits = 16393 // MAC/40
AddrFamilyIPv6Initial64bits = 16394 // IPv6/64
AddrFamilyRBridgePortID = 16395 // RBridge Port ID
AddrFamilyTRILLNickname = 16396 // TRILL Nickname
)

View File

@ -28,23 +28,27 @@ var registries = []struct {
parse func(io.Writer, io.Reader) error parse func(io.Writer, io.Reader) error
}{ }{
{ {
"http://www.iana.org/assignments/dscp-registry/dscp-registry.xml", "https://www.iana.org/assignments/dscp-registry/dscp-registry.xml",
parseDSCPRegistry, parseDSCPRegistry,
}, },
{ {
"http://www.iana.org/assignments/ipv4-tos-byte/ipv4-tos-byte.xml", "https://www.iana.org/assignments/ipv4-tos-byte/ipv4-tos-byte.xml",
parseTOSTCByte, parseTOSTCByte,
}, },
{ {
"http://www.iana.org/assignments/protocol-numbers/protocol-numbers.xml", "https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xml",
parseProtocolNumbers, parseProtocolNumbers,
}, },
{
"http://www.iana.org/assignments/address-family-numbers/address-family-numbers.xml",
parseAddrFamilyNumbers,
},
} }
func main() { func main() {
var bb bytes.Buffer var bb bytes.Buffer
fmt.Fprintf(&bb, "// go generate gen.go\n") fmt.Fprintf(&bb, "// go generate gen.go\n")
fmt.Fprintf(&bb, "// GENERATED BY THE COMMAND ABOVE; DO NOT EDIT\n\n") fmt.Fprintf(&bb, "// Code generated by the command above; DO NOT EDIT.\n\n")
fmt.Fprintf(&bb, "// Package iana provides protocol number resources managed by the Internet Assigned Numbers Authority (IANA).\n") fmt.Fprintf(&bb, "// Package iana provides protocol number resources managed by the Internet Assigned Numbers Authority (IANA).\n")
fmt.Fprintf(&bb, `package iana // import "golang.org/x/net/internal/iana"`+"\n\n") fmt.Fprintf(&bb, `package iana // import "golang.org/x/net/internal/iana"`+"\n\n")
for _, r := range registries { for _, r := range registries {
@ -291,3 +295,93 @@ func (pn *protocolNumbers) escape() []canonProtocolRecord {
} }
return prs return prs
} }
func parseAddrFamilyNumbers(w io.Writer, r io.Reader) error {
dec := xml.NewDecoder(r)
var afn addrFamilylNumbers
if err := dec.Decode(&afn); err != nil {
return err
}
afrs := afn.escape()
fmt.Fprintf(w, "// %s, Updated: %s\n", afn.Title, afn.Updated)
fmt.Fprintf(w, "const (\n")
for _, afr := range afrs {
if afr.Name == "" {
continue
}
fmt.Fprintf(w, "AddrFamily%s = %d", afr.Name, afr.Value)
fmt.Fprintf(w, "// %s\n", afr.Descr)
}
fmt.Fprintf(w, ")\n")
return nil
}
type addrFamilylNumbers struct {
XMLName xml.Name `xml:"registry"`
Title string `xml:"title"`
Updated string `xml:"updated"`
RegTitle string `xml:"registry>title"`
Note string `xml:"registry>note"`
Records []struct {
Value string `xml:"value"`
Descr string `xml:"description"`
} `xml:"registry>record"`
}
type canonAddrFamilyRecord struct {
Name string
Descr string
Value int
}
func (afn *addrFamilylNumbers) escape() []canonAddrFamilyRecord {
afrs := make([]canonAddrFamilyRecord, len(afn.Records))
sr := strings.NewReplacer(
"IP version 4", "IPv4",
"IP version 6", "IPv6",
"Identifier", "ID",
"-", "",
"-", "",
"/", "",
".", "",
" ", "",
)
for i, afr := range afn.Records {
if strings.Contains(afr.Descr, "Unassigned") ||
strings.Contains(afr.Descr, "Reserved") {
continue
}
afrs[i].Descr = afr.Descr
s := strings.TrimSpace(afr.Descr)
switch s {
case "IP (IP version 4)":
afrs[i].Name = "IPv4"
case "IP6 (IP version 6)":
afrs[i].Name = "IPv6"
case "AFI for L2VPN information":
afrs[i].Name = "L2VPN"
case "E.164 with NSAP format subaddress":
afrs[i].Name = "E164withSubaddress"
case "MT IP: Multi-Topology IP version 4":
afrs[i].Name = "MTIPv4"
case "MAC/24":
afrs[i].Name = "MACFinal24bits"
case "MAC/40":
afrs[i].Name = "MACFinal40bits"
case "IPv6/64":
afrs[i].Name = "IPv6Initial64bits"
default:
n := strings.Index(s, "(")
if n > 0 {
s = s[:n]
}
n = strings.Index(s, ":")
if n > 0 {
s = s[:n]
}
afrs[i].Name = sr.Replace(s)
}
afrs[i].Value, _ = strconv.Atoi(afr.Value)
}
return afrs
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build nacl plan9 // +build js,wasm nacl plan9
package nettest package nettest

View File

@ -64,7 +64,7 @@ func TestableNetwork(network string) bool {
switch network { switch network {
case "unix", "unixgram": case "unix", "unixgram":
switch runtime.GOOS { switch runtime.GOOS {
case "android", "nacl", "plan9", "windows": case "android", "js", "nacl", "plan9", "windows":
return false return false
} }
if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") { if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
@ -72,8 +72,13 @@ func TestableNetwork(network string) bool {
} }
case "unixpacket": case "unixpacket":
switch runtime.GOOS { switch runtime.GOOS {
case "android", "darwin", "freebsd", "nacl", "plan9", "windows": case "android", "darwin", "freebsd", "js", "nacl", "plan9", "windows":
return false return false
case "netbsd":
// It passes on amd64 at least. 386 fails (Issue 22927). arm is unknown.
if runtime.GOARCH == "386" {
return false
}
} }
} }
return true return true

View File

@ -10,6 +10,10 @@ package socket
import "unsafe" import "unsafe"
func (v *iovec) set(b []byte) { func (v *iovec) set(b []byte) {
l := len(b)
if l == 0 {
return
}
v.Base = (*byte)(unsafe.Pointer(&b[0])) v.Base = (*byte)(unsafe.Pointer(&b[0]))
v.Len = uint32(len(b)) v.Len = uint32(l)
} }

View File

@ -10,6 +10,10 @@ package socket
import "unsafe" import "unsafe"
func (v *iovec) set(b []byte) { func (v *iovec) set(b []byte) {
l := len(b)
if l == 0 {
return
}
v.Base = (*byte)(unsafe.Pointer(&b[0])) v.Base = (*byte)(unsafe.Pointer(&b[0]))
v.Len = uint64(len(b)) v.Len = uint64(l)
} }

View File

@ -10,6 +10,10 @@ package socket
import "unsafe" import "unsafe"
func (v *iovec) set(b []byte) { func (v *iovec) set(b []byte) {
l := len(b)
if l == 0 {
return
}
v.Base = (*int8)(unsafe.Pointer(&b[0])) v.Base = (*int8)(unsafe.Pointer(&b[0]))
v.Len = uint64(len(b)) v.Len = uint64(l)
} }

View File

@ -7,6 +7,10 @@
package socket package socket
func (h *msghdr) setIov(vs []iovec) { func (h *msghdr) setIov(vs []iovec) {
l := len(vs)
if l == 0 {
return
}
h.Iov = &vs[0] h.Iov = &vs[0]
h.Iovlen = int32(len(vs)) h.Iovlen = int32(l)
} }

View File

@ -10,8 +10,12 @@ package socket
import "unsafe" import "unsafe"
func (h *msghdr) setIov(vs []iovec) { func (h *msghdr) setIov(vs []iovec) {
l := len(vs)
if l == 0 {
return
}
h.Iov = &vs[0] h.Iov = &vs[0]
h.Iovlen = uint32(len(vs)) h.Iovlen = uint32(l)
} }
func (h *msghdr) setControl(b []byte) { func (h *msghdr) setControl(b []byte) {

View File

@ -10,8 +10,12 @@ package socket
import "unsafe" import "unsafe"
func (h *msghdr) setIov(vs []iovec) { func (h *msghdr) setIov(vs []iovec) {
l := len(vs)
if l == 0 {
return
}
h.Iov = &vs[0] h.Iov = &vs[0]
h.Iovlen = uint64(len(vs)) h.Iovlen = uint64(l)
} }
func (h *msghdr) setControl(b []byte) { func (h *msghdr) setControl(b []byte) {

View File

@ -5,6 +5,10 @@
package socket package socket
func (h *msghdr) setIov(vs []iovec) { func (h *msghdr) setIov(vs []iovec) {
l := len(vs)
if l == 0 {
return
}
h.Iov = &vs[0] h.Iov = &vs[0]
h.Iovlen = uint32(len(vs)) h.Iovlen = uint32(l)
} }

View File

@ -13,8 +13,10 @@ func (h *msghdr) pack(vs []iovec, bs [][]byte, oob []byte, sa []byte) {
for i := range vs { for i := range vs {
vs[i].set(bs[i]) vs[i].set(bs[i])
} }
h.Iov = &vs[0] if len(vs) > 0 {
h.Iovlen = int32(len(vs)) h.Iov = &vs[0]
h.Iovlen = int32(len(vs))
}
if len(oob) > 0 { if len(oob) > 0 {
h.Accrights = (*int8)(unsafe.Pointer(&oob[0])) h.Accrights = (*int8)(unsafe.Pointer(&oob[0]))
h.Accrightslen = int32(len(oob)) h.Accrightslen = int32(len(oob))

View File

@ -110,7 +110,7 @@ func ControlMessageSpace(dataLen int) int {
type ControlMessage []byte type ControlMessage []byte
// Data returns the data field of the control message at the head on // Data returns the data field of the control message at the head on
// w. // m.
func (m ControlMessage) Data(dataLen int) []byte { func (m ControlMessage) Data(dataLen int) []byte {
l := controlHeaderLen() l := controlHeaderLen()
if len(m) < l || len(m) < l+dataLen { if len(m) < l || len(m) < l+dataLen {
@ -119,7 +119,7 @@ func (m ControlMessage) Data(dataLen int) []byte {
return m[l : l+dataLen] return m[l : l+dataLen]
} }
// Next returns the control message at the next on w. // Next returns the control message at the next on m.
// //
// Next works only for standard control messages. // Next works only for standard control messages.
func (m ControlMessage) Next(dataLen int) ControlMessage { func (m ControlMessage) Next(dataLen int) ControlMessage {
@ -131,7 +131,7 @@ func (m ControlMessage) Next(dataLen int) ControlMessage {
} }
// MarshalHeader marshals the header fields of the control message at // MarshalHeader marshals the header fields of the control message at
// the head on w. // the head on m.
func (m ControlMessage) MarshalHeader(lvl, typ, dataLen int) error { func (m ControlMessage) MarshalHeader(lvl, typ, dataLen int) error {
if len(m) < controlHeaderLen() { if len(m) < controlHeaderLen() {
return errors.New("short message") return errors.New("short message")
@ -142,7 +142,7 @@ func (m ControlMessage) MarshalHeader(lvl, typ, dataLen int) error {
} }
// ParseHeader parses and returns the header fields of the control // ParseHeader parses and returns the header fields of the control
// message at the head on w. // message at the head on m.
func (m ControlMessage) ParseHeader() (lvl, typ, dataLen int, err error) { func (m ControlMessage) ParseHeader() (lvl, typ, dataLen int, err error) {
l := controlHeaderLen() l := controlHeaderLen()
if len(m) < l { if len(m) < l {
@ -152,7 +152,7 @@ func (m ControlMessage) ParseHeader() (lvl, typ, dataLen int, err error) {
return h.lvl(), h.typ(), int(uint64(h.len()) - uint64(l)), nil return h.lvl(), h.typ(), int(uint64(h.len()) - uint64(l)), nil
} }
// Marshal marshals the control message at the head on w, and returns // Marshal marshals the control message at the head on m, and returns
// the next control message. // the next control message.
func (m ControlMessage) Marshal(lvl, typ int, data []byte) (ControlMessage, error) { func (m ControlMessage) Marshal(lvl, typ int, data []byte) (ControlMessage, error) {
l := len(data) l := len(data)
@ -167,7 +167,7 @@ func (m ControlMessage) Marshal(lvl, typ int, data []byte) (ControlMessage, erro
return m.Next(l), nil return m.Next(l), nil
} }
// Parse parses w as a single or multiple control messages. // Parse parses m as a single or multiple control messages.
// //
// Parse works for both standard and compatible messages. // Parse works for both standard and compatible messages.
func (m ControlMessage) Parse() ([]ControlMessage, error) { func (m ControlMessage) Parse() ([]ControlMessage, error) {
@ -175,6 +175,9 @@ func (m ControlMessage) Parse() ([]ControlMessage, error) {
for len(m) >= controlHeaderLen() { for len(m) >= controlHeaderLen() {
h := (*cmsghdr)(unsafe.Pointer(&m[0])) h := (*cmsghdr)(unsafe.Pointer(&m[0]))
l := h.len() l := h.len()
if l <= 0 {
return nil, errors.New("invalid header length")
}
if uint64(l) < uint64(controlHeaderLen()) { if uint64(l) < uint64(controlHeaderLen()) {
return nil, errors.New("invalid message length") return nil, errors.New("invalid message length")
} }

View File

@ -119,81 +119,84 @@ func TestUDP(t *testing.T) {
t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
} }
defer c.Close() defer c.Close()
cc, err := socket.NewConn(c.(net.Conn))
if err != nil {
t.Fatal(err)
}
t.Run("Message", func(t *testing.T) { t.Run("Message", func(t *testing.T) {
testUDPMessage(t, c.(net.Conn)) data := []byte("HELLO-R-U-THERE")
wm := socket.Message{
Buffers: bytes.SplitAfter(data, []byte("-")),
Addr: c.LocalAddr(),
}
if err := cc.SendMsg(&wm, 0); err != nil {
t.Fatal(err)
}
b := make([]byte, 32)
rm := socket.Message{
Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
}
if err := cc.RecvMsg(&rm, 0); err != nil {
t.Fatal(err)
}
if !bytes.Equal(b[:rm.N], data) {
t.Fatalf("got %#v; want %#v", b[:rm.N], data)
}
}) })
switch runtime.GOOS { switch runtime.GOOS {
case "linux": case "android", "linux":
t.Run("Messages", func(t *testing.T) { t.Run("Messages", func(t *testing.T) {
testUDPMessages(t, c.(net.Conn)) data := []byte("HELLO-R-U-THERE")
wmbs := bytes.SplitAfter(data, []byte("-"))
wms := []socket.Message{
{Buffers: wmbs[:1], Addr: c.LocalAddr()},
{Buffers: wmbs[1:], Addr: c.LocalAddr()},
}
n, err := cc.SendMsgs(wms, 0)
if err != nil {
t.Fatal(err)
}
if n != len(wms) {
t.Fatalf("got %d; want %d", n, len(wms))
}
b := make([]byte, 32)
rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}}
rms := []socket.Message{
{Buffers: rmbs[0]},
{Buffers: rmbs[1]},
}
n, err = cc.RecvMsgs(rms, 0)
if err != nil {
t.Fatal(err)
}
if n != len(rms) {
t.Fatalf("got %d; want %d", n, len(rms))
}
nn := 0
for i := 0; i < n; i++ {
nn += rms[i].N
}
if !bytes.Equal(b[:nn], data) {
t.Fatalf("got %#v; want %#v", b[:nn], data)
}
}) })
} }
}
func testUDPMessage(t *testing.T, c net.Conn) { // The behavior of transmission for zero byte paylaod depends
cc, err := socket.NewConn(c) // on each platform implementation. Some may transmit only
if err != nil { // protocol header and options, other may transmit nothing.
t.Fatal(err) // We test only that SendMsg and SendMsgs will not crash with
} // empty buffers.
data := []byte("HELLO-R-U-THERE")
wm := socket.Message{ wm := socket.Message{
Buffers: bytes.SplitAfter(data, []byte("-")), Buffers: [][]byte{{}},
Addr: c.LocalAddr(), Addr: c.LocalAddr(),
} }
if err := cc.SendMsg(&wm, 0); err != nil { cc.SendMsg(&wm, 0)
t.Fatal(err)
}
b := make([]byte, 32)
rm := socket.Message{
Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
}
if err := cc.RecvMsg(&rm, 0); err != nil {
t.Fatal(err)
}
if !bytes.Equal(b[:rm.N], data) {
t.Fatalf("got %#v; want %#v", b[:rm.N], data)
}
}
func testUDPMessages(t *testing.T, c net.Conn) {
cc, err := socket.NewConn(c)
if err != nil {
t.Fatal(err)
}
data := []byte("HELLO-R-U-THERE")
wmbs := bytes.SplitAfter(data, []byte("-"))
wms := []socket.Message{ wms := []socket.Message{
{Buffers: wmbs[:1], Addr: c.LocalAddr()}, {Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
{Buffers: wmbs[1:], Addr: c.LocalAddr()},
}
n, err := cc.SendMsgs(wms, 0)
if err != nil {
t.Fatal(err)
}
if n != len(wms) {
t.Fatalf("got %d; want %d", n, len(wms))
}
b := make([]byte, 32)
rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}}
rms := []socket.Message{
{Buffers: rmbs[0]},
{Buffers: rmbs[1]},
}
n, err = cc.RecvMsgs(rms, 0)
if err != nil {
t.Fatal(err)
}
if n != len(rms) {
t.Fatalf("got %d; want %d", n, len(rms))
}
nn := 0
for i := 0; i < n; i++ {
nn += rms[i].N
}
if !bytes.Equal(b[:nn], data) {
t.Fatalf("got %#v; want %#v", b[:nn], data)
} }
cc.SendMsgs(wms, 0)
} }
func BenchmarkUDP(b *testing.B) { func BenchmarkUDP(b *testing.B) {
@ -230,7 +233,7 @@ func BenchmarkUDP(b *testing.B) {
} }
}) })
switch runtime.GOOS { switch runtime.GOOS {
case "linux": case "android", "linux":
wms := make([]socket.Message, M) wms := make([]socket.Message, M)
for i := range wms { for i := range wms {
wms[i].Buffers = [][]byte{data} wms[i].Buffers = [][]byte{data}

View File

@ -34,7 +34,7 @@ func marshalSockaddr(ip net.IP, port int, zone string) []byte {
if ip4 := ip.To4(); ip4 != nil { if ip4 := ip.To4(); ip4 != nil {
b := make([]byte, sizeofSockaddrInet) b := make([]byte, sizeofSockaddrInet)
switch runtime.GOOS { switch runtime.GOOS {
case "linux", "solaris", "windows": case "android", "linux", "solaris", "windows":
NativeEndian.PutUint16(b[:2], uint16(sysAF_INET)) NativeEndian.PutUint16(b[:2], uint16(sysAF_INET))
default: default:
b[0] = sizeofSockaddrInet b[0] = sizeofSockaddrInet
@ -47,7 +47,7 @@ func marshalSockaddr(ip net.IP, port int, zone string) []byte {
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil { if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
b := make([]byte, sizeofSockaddrInet6) b := make([]byte, sizeofSockaddrInet6)
switch runtime.GOOS { switch runtime.GOOS {
case "linux", "solaris", "windows": case "android", "linux", "solaris", "windows":
NativeEndian.PutUint16(b[:2], uint16(sysAF_INET6)) NativeEndian.PutUint16(b[:2], uint16(sysAF_INET6))
default: default:
b[0] = sizeofSockaddrInet6 b[0] = sizeofSockaddrInet6
@ -69,7 +69,7 @@ func parseInetAddr(b []byte, network string) (net.Addr, error) {
} }
var af int var af int
switch runtime.GOOS { switch runtime.GOOS {
case "linux", "solaris", "windows": case "android", "linux", "solaris", "windows":
af = int(NativeEndian.Uint16(b[:2])) af = int(NativeEndian.Uint16(b[:2]))
default: default:
af = int(b[1]) af = int(b[1])

View File

@ -0,0 +1,61 @@
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs defs_darwin.go
package socket
const (
sysAF_UNSPEC = 0x0
sysAF_INET = 0x2
sysAF_INET6 = 0x1e
sysSOCK_RAW = 0x3
)
type iovec struct {
Base *byte
Len uint64
}
type msghdr struct {
Name *byte
Namelen uint32
Pad_cgo_0 [4]byte
Iov *iovec
Iovlen int32
Pad_cgo_1 [4]byte
Control *byte
Controllen uint32
Flags int32
}
type cmsghdr struct {
Len uint32
Level int32
Type int32
}
type sockaddrInet struct {
Len uint8
Family uint8
Port uint16
Addr [4]byte /* in_addr */
Zero [8]int8
}
type sockaddrInet6 struct {
Len uint8
Family uint8
Port uint16
Flowinfo uint32
Addr [16]byte /* in6_addr */
Scope_id uint32
}
const (
sizeofIovec = 0x10
sizeofMsghdr = 0x30
sizeofCmsghdr = 0xc
sizeofSockaddrInet = 0x10
sizeofSockaddrInet6 = 0x1c
)

View File

@ -26,6 +26,11 @@ type msghdr struct {
Flags int32 Flags int32
} }
type mmsghdr struct {
Hdr msghdr
Len uint32
}
type cmsghdr struct { type cmsghdr struct {
Len uint32 Len uint32
Level int32 Level int32
@ -52,6 +57,7 @@ type sockaddrInet6 struct {
const ( const (
sizeofIovec = 0x8 sizeofIovec = 0x8
sizeofMsghdr = 0x1c sizeofMsghdr = 0x1c
sizeofMmsghdr = 0x20
sizeofCmsghdr = 0xc sizeofCmsghdr = 0xc
sizeofSockaddrInet = 0x10 sizeofSockaddrInet = 0x10

168
vendor/golang.org/x/net/internal/socks/client.go generated vendored Normal file
View File

@ -0,0 +1,168 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package socks
import (
"context"
"errors"
"io"
"net"
"strconv"
"time"
)
var (
noDeadline = time.Time{}
aLongTimeAgo = time.Unix(1, 0)
)
func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
host, port, err := splitHostPort(address)
if err != nil {
return nil, err
}
if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
c.SetDeadline(deadline)
defer c.SetDeadline(noDeadline)
}
if ctx != context.Background() {
errCh := make(chan error, 1)
done := make(chan struct{})
defer func() {
close(done)
if ctxErr == nil {
ctxErr = <-errCh
}
}()
go func() {
select {
case <-ctx.Done():
c.SetDeadline(aLongTimeAgo)
errCh <- ctx.Err()
case <-done:
errCh <- nil
}
}()
}
b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
b = append(b, Version5)
if len(d.AuthMethods) == 0 || d.Authenticate == nil {
b = append(b, 1, byte(AuthMethodNotRequired))
} else {
ams := d.AuthMethods
if len(ams) > 255 {
return nil, errors.New("too many authentication methods")
}
b = append(b, byte(len(ams)))
for _, am := range ams {
b = append(b, byte(am))
}
}
if _, ctxErr = c.Write(b); ctxErr != nil {
return
}
if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
return
}
if b[0] != Version5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
am := AuthMethod(b[1])
if am == AuthMethodNoAcceptableMethods {
return nil, errors.New("no acceptable authentication methods")
}
if d.Authenticate != nil {
if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
return
}
}
b = b[:0]
b = append(b, Version5, byte(d.cmd), 0)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
b = append(b, AddrTypeIPv4)
b = append(b, ip4...)
} else if ip6 := ip.To16(); ip6 != nil {
b = append(b, AddrTypeIPv6)
b = append(b, ip6...)
} else {
return nil, errors.New("unknown address type")
}
} else {
if len(host) > 255 {
return nil, errors.New("FQDN too long")
}
b = append(b, AddrTypeFQDN)
b = append(b, byte(len(host)))
b = append(b, host...)
}
b = append(b, byte(port>>8), byte(port))
if _, ctxErr = c.Write(b); ctxErr != nil {
return
}
if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
return
}
if b[0] != Version5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
return nil, errors.New("unknown error " + cmdErr.String())
}
if b[2] != 0 {
return nil, errors.New("non-zero reserved field")
}
l := 2
var a Addr
switch b[3] {
case AddrTypeIPv4:
l += net.IPv4len
a.IP = make(net.IP, net.IPv4len)
case AddrTypeIPv6:
l += net.IPv6len
a.IP = make(net.IP, net.IPv6len)
case AddrTypeFQDN:
if _, err := io.ReadFull(c, b[:1]); err != nil {
return nil, err
}
l += int(b[0])
default:
return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
}
if cap(b) < l {
b = make([]byte, l)
} else {
b = b[:l]
}
if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
return
}
if a.IP != nil {
copy(a.IP, b)
} else {
a.Name = string(b[:len(b)-2])
}
a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
return &a, nil
}
func splitHostPort(address string) (string, int, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return "", 0, err
}
portnum, err := strconv.Atoi(port)
if err != nil {
return "", 0, err
}
if 1 > portnum || portnum > 0xffff {
return "", 0, errors.New("port number out of range " + port)
}
return host, portnum, nil
}

158
vendor/golang.org/x/net/internal/socks/dial_test.go generated vendored Normal file
View File

@ -0,0 +1,158 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package socks_test
import (
"context"
"io"
"math/rand"
"net"
"os"
"testing"
"time"
"golang.org/x/net/internal/socks"
"golang.org/x/net/internal/sockstest"
)
const (
targetNetwork = "tcp6"
targetHostname = "fqdn.doesnotexist"
targetHostIP = "2001:db8::1"
targetPort = "5963"
)
func TestDial(t *testing.T) {
t.Run("Connect", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
if err != nil {
t.Error(err)
return
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
d.AuthMethods = []socks.AuthMethod{
socks.AuthMethodNotRequired,
socks.AuthMethodUsernamePassword,
}
d.Authenticate = (&socks.UsernamePassword{
Username: "username",
Password: "password",
}).Authenticate
c, err := d.Dial(targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
if err == nil {
c.(*socks.Conn).BoundAddr()
c.Close()
}
if err != nil {
t.Error(err)
return
}
})
t.Run("Cancel", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
if err != nil {
t.Error(err)
return
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dialErr := make(chan error)
go func() {
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
if err == nil {
c.Close()
}
dialErr <- err
}()
time.Sleep(100 * time.Millisecond)
cancel()
err = <-dialErr
if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil {
t.Errorf("got %v; want context.Canceled or equivalent", err)
return
}
})
t.Run("Deadline", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
if err != nil {
t.Error(err)
return
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer cancel()
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
if err == nil {
c.Close()
}
if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil {
t.Errorf("got %v; want context.DeadlineExceeded or equivalent", err)
return
}
})
t.Run("WithRogueServer", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc)
if err != nil {
t.Error(err)
return
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
for i := 0; i < 2*len(rogueCmdList); i++ {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer cancel()
c, err := d.DialContext(ctx, targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
if err == nil {
t.Log(c.(*socks.Conn).BoundAddr())
c.Close()
t.Error("should fail")
}
}
})
}
func blackholeCmdFunc(rw io.ReadWriter, b []byte) error {
if _, err := sockstest.ParseCmdRequest(b); err != nil {
return err
}
var bb [1]byte
for {
if _, err := rw.Read(bb[:]); err != nil {
return err
}
}
}
func rogueCmdFunc(rw io.ReadWriter, b []byte) error {
if _, err := sockstest.ParseCmdRequest(b); err != nil {
return err
}
rw.Write(rogueCmdList[rand.Intn(len(rogueCmdList))])
return nil
}
var rogueCmdList = [][]byte{
{0x05},
{0x06, 0x00, 0x00, 0x01, 192, 0, 2, 1, 0x17, 0x4b},
{0x05, 0x00, 0xff, 0x01, 192, 0, 2, 2, 0x17, 0x4b},
{0x05, 0x00, 0x00, 0x01, 192, 0, 2, 3},
{0x05, 0x00, 0x00, 0x03, 0x04, 'F', 'Q', 'D', 'N'},
}
func parseDialError(err error) (perr, nerr error) {
if e, ok := err.(*net.OpError); ok {
err = e.Err
nerr = e
}
if e, ok := err.(*os.SyscallError); ok {
err = e.Err
}
perr = err
return
}

265
vendor/golang.org/x/net/internal/socks/socks.go generated vendored Normal file
View File

@ -0,0 +1,265 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package socks provides a SOCKS version 5 client implementation.
//
// SOCKS protocol version 5 is defined in RFC 1928.
// Username/Password authentication for SOCKS version 5 is defined in
// RFC 1929.
package socks
import (
"context"
"errors"
"io"
"net"
"strconv"
)
// A Command represents a SOCKS command.
type Command int
func (cmd Command) String() string {
switch cmd {
case CmdConnect:
return "socks connect"
case cmdBind:
return "socks bind"
default:
return "socks " + strconv.Itoa(int(cmd))
}
}
// An AuthMethod represents a SOCKS authentication method.
type AuthMethod int
// A Reply represents a SOCKS command reply code.
type Reply int
func (code Reply) String() string {
switch code {
case StatusSucceeded:
return "succeeded"
case 0x01:
return "general SOCKS server failure"
case 0x02:
return "connection not allowed by ruleset"
case 0x03:
return "network unreachable"
case 0x04:
return "host unreachable"
case 0x05:
return "connection refused"
case 0x06:
return "TTL expired"
case 0x07:
return "command not supported"
case 0x08:
return "address type not supported"
default:
return "unknown code: " + strconv.Itoa(int(code))
}
}
// Wire protocol constants.
const (
Version5 = 0x05
AddrTypeIPv4 = 0x01
AddrTypeFQDN = 0x03
AddrTypeIPv6 = 0x04
CmdConnect Command = 0x01 // establishes an active-open forward proxy connection
cmdBind Command = 0x02 // establishes a passive-open forward proxy connection
AuthMethodNotRequired AuthMethod = 0x00 // no authentication required
AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password
AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authetication methods
StatusSucceeded Reply = 0x00
)
// An Addr represents a SOCKS-specific address.
// Either Name or IP is used exclusively.
type Addr struct {
Name string // fully-qualified domain name
IP net.IP
Port int
}
func (a *Addr) Network() string { return "socks" }
func (a *Addr) String() string {
if a == nil {
return "<nil>"
}
port := strconv.Itoa(a.Port)
if a.IP == nil {
return net.JoinHostPort(a.Name, port)
}
return net.JoinHostPort(a.IP.String(), port)
}
// A Conn represents a forward proxy connection.
type Conn struct {
net.Conn
boundAddr net.Addr
}
// BoundAddr returns the address assigned by the proxy server for
// connecting to the command target address from the proxy server.
func (c *Conn) BoundAddr() net.Addr {
if c == nil {
return nil
}
return c.boundAddr
}
// A Dialer holds SOCKS-specific options.
type Dialer struct {
cmd Command // either CmdConnect or cmdBind
proxyNetwork string // network between a proxy server and a client
proxyAddress string // proxy server address
// ProxyDial specifies the optional dial function for
// establishing the transport connection.
ProxyDial func(context.Context, string, string) (net.Conn, error)
// AuthMethods specifies the list of request authention
// methods.
// If empty, SOCKS client requests only AuthMethodNotRequired.
AuthMethods []AuthMethod
// Authenticate specifies the optional authentication
// function. It must be non-nil when AuthMethods is not empty.
// It must return an error when the authentication is failed.
Authenticate func(context.Context, io.ReadWriter, AuthMethod) error
}
// DialContext connects to the provided address on the provided
// network.
//
// The returned error value may be a net.OpError. When the Op field of
// net.OpError contains "socks", the Source field contains a proxy
// server address and the Addr field contains a command target
// address.
//
// See func Dial of the net package of standard library for a
// description of the network and address parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
switch network {
case "tcp", "tcp6", "tcp4":
default:
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("network not implemented")}
}
switch d.cmd {
case CmdConnect, cmdBind:
default:
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("command not implemented")}
}
if ctx == nil {
ctx = context.Background()
}
var err error
var c net.Conn
if d.ProxyDial != nil {
c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
} else {
var dd net.Dialer
c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
}
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
a, err := d.connect(ctx, c, address)
if err != nil {
c.Close()
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
return &Conn{Conn: c, boundAddr: a}, nil
}
// Dial connects to the provided address on the provided network.
//
// Deprecated: Use DialContext instead.
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
for i, s := range []string{d.proxyAddress, address} {
host, port, err := splitHostPort(s)
if err != nil {
return nil, nil, err
}
a := &Addr{Port: port}
a.IP = net.ParseIP(host)
if a.IP == nil {
a.Name = host
}
if i == 0 {
proxy = a
} else {
dst = a
}
}
return
}
// NewDialer returns a new Dialer that dials through the provided
// proxy server's network and address.
func NewDialer(network, address string) *Dialer {
return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect}
}
const (
authUsernamePasswordVersion = 0x01
authStatusSucceeded = 0x00
)
// UsernamePassword are the credentials for the username/password
// authentication method.
type UsernamePassword struct {
Username string
Password string
}
// Authenticate authenticates a pair of username and password with the
// proxy server.
func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error {
switch auth {
case AuthMethodNotRequired:
return nil
case AuthMethodUsernamePassword:
if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 {
return errors.New("invalid username/password")
}
b := []byte{authUsernamePasswordVersion}
b = append(b, byte(len(up.Username)))
b = append(b, up.Username...)
b = append(b, byte(len(up.Password)))
b = append(b, up.Password...)
// TODO(mikio): handle IO deadlines and cancelation if
// necessary
if _, err := rw.Write(b); err != nil {
return err
}
if _, err := io.ReadFull(rw, b[:2]); err != nil {
return err
}
if b[0] != authUsernamePasswordVersion {
return errors.New("invalid username/password version")
}
if b[1] != authStatusSucceeded {
return errors.New("username/password authentication failed")
}
return nil
}
return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
}

Some files were not shown because too many files have changed in this diff Show More