From 355e9bd619c8ec8bdd0f2934739a87a15ceb6920 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 14 May 2018 12:27:29 +0200 Subject: [PATCH] Clean more --- conn_linux.go | 17 ++++------------- cookie.go | 32 ++++++++++++++++---------------- device.go | 14 ++++---------- main.go | 2 +- noise-protocol.go | 10 +++++----- tun_darwin.go | 13 +++++++------ tun_linux.go | 29 +++++++++++++++-------------- 7 files changed, 52 insertions(+), 65 deletions(-) diff --git a/conn_linux.go b/conn_linux.go index 2b920bf..8d076ac 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -217,19 +217,6 @@ func (bind *NativeBind) Send(buff []byte, end Endpoint) error { } } -func rawAddrToIP4(addr *unix.SockaddrInet4) net.IP { - return net.IPv4( - addr.Addr[0], - addr.Addr[1], - addr.Addr[2], - addr.Addr[3], - ) -} - -func rawAddrToIP6(addr *unix.SockaddrInet6) net.IP { - return addr.Addr[:] -} - func (end *NativeEndpoint) SrcIP() net.IP { if !end.isV6 { return net.IPv4( @@ -624,6 +611,10 @@ func (bind *NativeBind) routineRouteListener(device *Device) { peer.mutex.RUnlock() continue } + if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { + peer.mutex.RUnlock() + break + } nlmsg := struct { hdr unix.NlMsghdr msg unix.RtMsg diff --git a/cookie.go b/cookie.go index cfee367..c0d3ed9 100644 --- a/cookie.go +++ b/cookie.go @@ -48,19 +48,19 @@ func (st *CookieChecker) Init(pk NoisePublicKey) { // mac1 state func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelMAC1)) - hsh.Write(pk[:]) - hsh.Sum(st.mac1.key[:0]) + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelMAC1)) + hash.Write(pk[:]) + hash.Sum(st.mac1.key[:0]) }() // mac2 state func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelCookie)) - hsh.Write(pk[:]) - hsh.Sum(st.mac2.encryptionKey[:0]) + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelCookie)) + hash.Write(pk[:]) + hash.Sum(st.mac2.encryptionKey[:0]) }() st.mac2.secretSet = time.Time{} @@ -181,17 +181,17 @@ func (st *CookieGenerator) Init(pk NoisePublicKey) { defer st.mutex.Unlock() func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelMAC1)) - hsh.Write(pk[:]) - hsh.Sum(st.mac1.key[:0]) + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelMAC1)) + hash.Write(pk[:]) + hash.Sum(st.mac1.key[:0]) }() func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelCookie)) - hsh.Write(pk[:]) - hsh.Sum(st.mac2.encryptionKey[:0]) + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelCookie)) + hash.Write(pk[:]) + hash.Sum(st.mac2.encryptionKey[:0]) }() st.mac2.cookieSet = time.Time{} diff --git a/device.go b/device.go index 835a755..e91ca72 100644 --- a/device.go +++ b/device.go @@ -225,15 +225,15 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { for key, peer := range device.peers.keyMap { - hs := &peer.handshake + handshake := &peer.handshake if rmKey { - hs.precomputedStaticStatic = [NoisePublicKeySize]byte{} + handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{} } else { - hs.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(hs.remoteStatic) + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) } - if isZero(hs.precomputedStaticStatic[:]) { + if isZero(handshake.precomputedStaticStatic[:]) { unsafeRemovePeer(device, peer, key) } } @@ -267,18 +267,12 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { device.peers.keyMap = make(map[NoisePublicKey]*Peer) - // initialize rate limiter - device.rate.limiter.Init() device.rate.underLoadUntil.Store(time.Time{}) - // initialize staticIdentity & crypt-key routine - device.indexTable.Init() device.allowedips.Reset() - // setup buffer pool - device.pool.messageBuffers = sync.Pool{ New: func() interface{} { return new([MaxMessageSize]byte) diff --git a/main.go b/main.go index 6c7b07d..c9ef343 100644 --- a/main.go +++ b/main.go @@ -186,7 +186,7 @@ func main() { env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND)) files := [3]*os.File{} - if os.Getenv("LOG_LEVEL") != "" { + if os.Getenv("LOG_LEVEL") != "" && logLevel != LogLevelSilent { files[1] = os.Stdout files[2] = os.Stderr } diff --git a/noise-protocol.go b/noise-protocol.go index ffc2b50..c134107 100644 --- a/noise-protocol.go +++ b/noise-protocol.go @@ -121,11 +121,11 @@ func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { } func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { - hsh, _ := blake2s.New256(nil) - hsh.Write(h[:]) - hsh.Write(data) - hsh.Sum(dst[:0]) - hsh.Reset() + hash, _ := blake2s.New256(nil) + hash.Write(h[:]) + hash.Write(data) + hash.Sum(dst[:0]) + hash.Reset() } func (h *Handshake) Clear() { diff --git a/tun_darwin.go b/tun_darwin.go index b212e57..ac8bffd 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -125,12 +125,6 @@ func CreateTUNFromFile(file *os.File) (TUNDevice, error) { return nil, err } - // set default MTU - err = tun.setMTU(DefaultMTU) - if err != nil { - return nil, err - } - tun.rwcancel, err = rwcancel.NewRWCancel(int(file.Fd())) if err != nil { return nil, err @@ -174,6 +168,13 @@ func CreateTUNFromFile(file *os.File) (TUNDevice, error) { } }(tun) + // set default MTU + err = tun.setMTU(DefaultMTU) + if err != nil { + tun.Close() + return nil, err + } + return tun, nil } diff --git a/tun_linux.go b/tun_linux.go index 18fb72c..8e42d44 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -395,7 +395,7 @@ func CreateTUN(name string) (TUNDevice, error) { } func CreateTUNFromFile(fd *os.File) (TUNDevice, error) { - device := &NativeTun{ + tun := &NativeTun{ fd: fd, events: make(chan TUNEvent, 5), errors: make(chan error, 5), @@ -404,37 +404,38 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) { } var err error - device.rwcancel, err = rwcancel.NewRWCancel(int(fd.Fd())) + tun.rwcancel, err = rwcancel.NewRWCancel(int(fd.Fd())) if err != nil { return nil, err } - _, err = device.Name() + _, err = tun.Name() if err != nil { return nil, err } // start event listener - device.index, err = getIFIndex(device.name) + tun.index, err = getIFIndex(tun.name) if err != nil { return nil, err } + tun.netlinkSock, err = createNetlinkSocket() + if err != nil { + return nil, err + } + + go tun.RoutineNetlinkListener() + go tun.RoutineHackListener() // cross namespace + // set default MTU - err = device.setMTU(DefaultMTU) + err = tun.setMTU(DefaultMTU) if err != nil { + tun.Close() return nil, err } - device.netlinkSock, err = createNetlinkSocket() - if err != nil { - return nil, err - } - - go device.RoutineNetlinkListener() - go device.RoutineHackListener() // cross namespace - - return device, nil + return tun, nil }