Netlink sockets can't be shutdown

This commit is contained in:
Jason A. Donenfeld 2018-05-14 14:08:03 +02:00
parent 2dfd4e7d8c
commit 795f76cffa
4 changed files with 69 additions and 25 deletions

View File

@ -15,6 +15,7 @@
package main package main
import ( import (
"./rwcancel"
"errors" "errors"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"net" "net"
@ -55,10 +56,11 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
} }
type NativeBind struct { type NativeBind struct {
sock4 int sock4 int
sock6 int sock6 int
netlinkSock int netlinkSock int
lastMark uint32 netlinkCancel *rwcancel.RWCancel
lastMark uint32
} }
var _ Endpoint = (*NativeEndpoint)(nil) var _ Endpoint = (*NativeEndpoint)(nil)
@ -125,18 +127,23 @@ func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
if err != nil {
unix.Close(bind.netlinkSock)
return nil, 0, err
}
go bind.routineRouteListener(device) go bind.routineRouteListener(device)
bind.sock6, port, err = create6(port) bind.sock6, port, err = create6(port)
if err != nil { if err != nil {
unix.Close(bind.netlinkSock) bind.netlinkCancel.Cancel()
return nil, port, err return nil, port, err
} }
bind.sock4, port, err = create4(port) bind.sock4, port, err = create4(port)
if err != nil { if err != nil {
unix.Close(bind.netlinkSock) bind.netlinkCancel.Cancel()
unix.Close(bind.sock6) unix.Close(bind.sock6)
} }
return &bind, port, err return &bind, port, err
@ -178,7 +185,8 @@ func closeUnblock(fd int) error {
func (bind *NativeBind) Close() error { func (bind *NativeBind) Close() error {
err1 := closeUnblock(bind.sock6) err1 := closeUnblock(bind.sock6)
err2 := closeUnblock(bind.sock4) err2 := closeUnblock(bind.sock4)
err3 := closeUnblock(bind.netlinkSock) err3 := bind.netlinkCancel.Cancel()
if err1 != nil { if err1 != nil {
return err1 return err1
} }
@ -539,8 +547,20 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
func (bind *NativeBind) routineRouteListener(device *Device) { func (bind *NativeBind) routineRouteListener(device *Device) {
var reqPeer map[uint32]*Peer var reqPeer map[uint32]*Peer
defer unix.Close(bind.netlinkSock)
for msg := make([]byte, 1<<16); ; { for msg := make([]byte, 1<<16); ; {
msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
break
}
if !bind.netlinkCancel.ReadyRead() {
return
}
}
if err != nil { if err != nil {
return return
} }

View File

@ -221,14 +221,10 @@ func main() {
return return
} }
// create wireguard device
device := NewDevice(tun, logger) device := NewDevice(tun, logger)
logger.Info.Println("Device started") logger.Info.Println("Device started")
// start uapi listener
errs := make(chan error) errs := make(chan error)
term := make(chan os.Signal) term := make(chan os.Signal)

View File

@ -122,11 +122,13 @@ func CreateTUNFromFile(file *os.File) (TUNDevice, error) {
_, err := tun.Name() _, err := tun.Name()
if err != nil { if err != nil {
tun.fd.Close()
return nil, err return nil, err
} }
tun.rwcancel, err = rwcancel.NewRWCancel(int(file.Fd())) tun.rwcancel, err = rwcancel.NewRWCancel(int(file.Fd()))
if err != nil { if err != nil {
tun.fd.Close()
return nil, err return nil, err
} }

View File

@ -31,14 +31,16 @@ const (
) )
type NativeTun struct { type NativeTun struct {
fd *os.File fd *os.File
index int32 // if index fdCancel *rwcancel.RWCancel
name string // name of interface index int32 // if index
errors chan error // async error handling name string // name of interface
events chan TUNEvent // device related events errors chan error // async error handling
nopi bool // the device was pased IFF_NO_PI events chan TUNEvent // device related events
rwcancel *rwcancel.RWCancel nopi bool // the device was pased IFF_NO_PI
netlinkSock int netlinkSock int
netlinkCancel *rwcancel.RWCancel
statusListenersShutdown chan struct{} statusListenersShutdown chan struct{}
} }
@ -86,9 +88,22 @@ func createNetlinkSocket() (int, error) {
} }
func (tun *NativeTun) RoutineNetlinkListener() { func (tun *NativeTun) RoutineNetlinkListener() {
defer unix.Close(tun.netlinkSock)
for msg := make([]byte, 1<<16); ; { for msg := make([]byte, 1<<16); ; {
msgn, _, _, _, err := unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0) var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
break
}
if !tun.netlinkCancel.ReadyRead() {
tun.errors <- fmt.Errorf("netlink socket closed: %s", err.Error())
return
}
}
if err != nil { if err != nil {
tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error()) tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error())
return return
@ -323,7 +338,7 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
if err == nil || !rwcancel.ErrorIsEAGAIN(err) { if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
return n, err return n, err
} }
if !tun.rwcancel.ReadyRead() { if !tun.fdCancel.ReadyRead() {
return 0, errors.New("tun device closed") return 0, errors.New("tun device closed")
} }
} }
@ -334,10 +349,13 @@ func (tun *NativeTun) Events() chan TUNEvent {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err1 error
close(tun.statusListenersShutdown) close(tun.statusListenersShutdown)
err1 := closeUnblock(tun.netlinkSock) if tun.netlinkCancel != nil {
err1 = tun.netlinkCancel.Cancel()
}
err2 := tun.fd.Close() err2 := tun.fd.Close()
err3 := tun.rwcancel.Cancel() err3 := tun.fdCancel.Cancel()
close(tun.events) close(tun.events)
if err1 != nil { if err1 != nil {
@ -404,13 +422,15 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
} }
var err error var err error
tun.rwcancel, err = rwcancel.NewRWCancel(int(fd.Fd())) tun.fdCancel, err = rwcancel.NewRWCancel(int(fd.Fd()))
if err != nil { if err != nil {
tun.fd.Close()
return nil, err return nil, err
} }
_, err = tun.Name() _, err = tun.Name()
if err != nil { if err != nil {
tun.fd.Close()
return nil, err return nil, err
} }
@ -423,6 +443,12 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
tun.netlinkSock, err = createNetlinkSocket() tun.netlinkSock, err = createNetlinkSocket()
if err != nil { if err != nil {
tun.fd.Close()
return nil, err
}
tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
if err != nil {
tun.fd.Close()
return nil, err return nil, err
} }