mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-25 01:23:22 +01:00
Improve module load (#470)
* Add additional check for needed kernel modules * Check if wireguard and tun modules are loaded If modules are loaded return true, otherwise attempt to load them * fix state check * Add module function tests * Add test execution in container * run client package tests on docker * add package comment to new file * force entrypoint * add --privileged flag * clean only if tables where created * run from within the directories
This commit is contained in:
parent
6de313070a
commit
e4ad6174ca
52
.github/workflows/golang-test-linux.yml
vendored
52
.github/workflows/golang-test-linux.yml
vendored
@ -33,3 +33,55 @@ jobs:
|
|||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
run: GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||||
|
|
||||||
|
test_client_on_docker:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v2
|
||||||
|
with:
|
||||||
|
go-version: 1.18.x
|
||||||
|
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: ~/go/pkg/mod
|
||||||
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libappindicator3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: Generate Iface Test bin
|
||||||
|
run: go test -c -o iface-testing.bin ./iface/...
|
||||||
|
|
||||||
|
- name: Generate RouteManager Test bin
|
||||||
|
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||||
|
|
||||||
|
- name: Generate Engine Test bin
|
||||||
|
run: go test -c -o engine-testing.bin ./client/internal/*.go
|
||||||
|
|
||||||
|
- name: Generate Peer Test bin
|
||||||
|
run: go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||||
|
|
||||||
|
- run: chmod +x *testing.bin
|
||||||
|
|
||||||
|
- name: Run Iface tests in docker
|
||||||
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin
|
||||||
|
|
||||||
|
- name: Run RouteManager tests in docker
|
||||||
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin
|
||||||
|
|
||||||
|
- name: Run Engine tests in docker
|
||||||
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin
|
||||||
|
|
||||||
|
- name: Run Peer tests in docker
|
||||||
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin
|
@ -107,7 +107,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
localPeerState := nbStatus.LocalPeerState{
|
localPeerState := nbStatus.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: iface.WireguardModExists(),
|
KernelInterface: iface.WireguardModuleIsLoaded(),
|
||||||
}
|
}
|
||||||
|
|
||||||
statusRecorder.UpdateLocalPeerState(localPeerState)
|
statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||||
|
@ -84,8 +84,10 @@ func (n *nftablesManager) CleanRoutingRules() {
|
|||||||
n.mux.Lock()
|
n.mux.Lock()
|
||||||
defer n.mux.Unlock()
|
defer n.mux.Unlock()
|
||||||
log.Debug("flushing tables")
|
log.Debug("flushing tables")
|
||||||
n.conn.FlushTable(n.tableIPv6)
|
if n.tableIPv4 != nil && n.tableIPv6 != nil {
|
||||||
n.conn.FlushTable(n.tableIPv4)
|
n.conn.FlushTable(n.tableIPv6)
|
||||||
|
n.conn.FlushTable(n.tableIPv4)
|
||||||
|
}
|
||||||
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ func (w *WGIface) assignAddr() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WireguardModExists check if we can load wireguard mod (linux only)
|
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||||
func WireguardModExists() bool {
|
func WireguardModuleIsLoaded() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -1,48 +1,29 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"fmt"
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
type NativeLink struct {
|
type NativeLink struct {
|
||||||
Link *netlink.Link
|
Link *netlink.Link
|
||||||
}
|
}
|
||||||
|
|
||||||
// WireguardModExists check if we can load wireguard mod (linux only)
|
|
||||||
func WireguardModExists() bool {
|
|
||||||
link := newWGLink("mustnotexist")
|
|
||||||
|
|
||||||
// We willingly try to create a device with an invalid
|
|
||||||
// MTU here as the validation of the MTU will be performed after
|
|
||||||
// the validation of the link kind and hence allows us to check
|
|
||||||
// for the existance of the wireguard module without actually
|
|
||||||
// creating a link.
|
|
||||||
//
|
|
||||||
// As a side-effect, this will also let the kernel lazy-load
|
|
||||||
// the wireguard module.
|
|
||||||
link.attrs.MTU = math.MaxInt
|
|
||||||
|
|
||||||
err := netlink.LinkAdd(link)
|
|
||||||
|
|
||||||
return errors.Is(err, syscall.EINVAL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
func (w *WGIface) Create() error {
|
func (w *WGIface) Create() error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
if WireguardModExists() {
|
if WireguardModuleIsLoaded() {
|
||||||
log.Info("using kernel WireGuard")
|
log.Info("using kernel WireGuard")
|
||||||
return w.createWithKernel()
|
return w.createWithKernel()
|
||||||
} else {
|
} else {
|
||||||
|
if !tunModuleIsLoaded() {
|
||||||
|
return fmt.Errorf("couldn't check or load tun module")
|
||||||
|
}
|
||||||
log.Info("using userspace WireGuard")
|
log.Info("using userspace WireGuard")
|
||||||
return w.createWithUserspace()
|
return w.createWithUserspace()
|
||||||
}
|
}
|
||||||
|
@ -58,7 +58,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
|||||||
return w.assignAddr(luid)
|
return w.assignAddr(luid)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WireguardModExists check if we can load wireguard mod (linux only)
|
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||||
func WireguardModExists() bool {
|
func WireguardModuleIsLoaded() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
349
iface/module_linux.go
Normal file
349
iface/module_linux.go
Normal file
@ -0,0 +1,349 @@
|
|||||||
|
// Package iface provides wireguard network interface creation and management
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"io/fs"
|
||||||
|
"io/ioutil"
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Holds logic to check existence of kernel modules used by wireguard interfaces
|
||||||
|
// Copied from https://github.com/paultag/go-modprobe and
|
||||||
|
// https://github.com/pmorjan/kmod
|
||||||
|
|
||||||
|
type status int
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultModuleDir = "/lib/modules"
|
||||||
|
unknown status = iota
|
||||||
|
unloaded
|
||||||
|
unloading
|
||||||
|
loading
|
||||||
|
live
|
||||||
|
inuse
|
||||||
|
)
|
||||||
|
|
||||||
|
type module struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrModuleNotFound is the error resulting if a module can't be found.
|
||||||
|
ErrModuleNotFound = errors.New("module not found")
|
||||||
|
moduleLibDir = defaultModuleDir
|
||||||
|
// get the root directory for the kernel modules. If this line panics,
|
||||||
|
// it's because getModuleRoot has failed to get the uname of the running
|
||||||
|
// kernel (likely a non-POSIX system, but maybe a broken kernel?)
|
||||||
|
moduleRoot = getModuleRoot()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get the module root (/lib/modules/$(uname -r)/)
|
||||||
|
func getModuleRoot() string {
|
||||||
|
uname := unix.Utsname{}
|
||||||
|
if err := unix.Uname(&uname); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for ; uname.Release[i] != 0; i++ {
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(moduleLibDir, string(uname.Release[:i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
|
||||||
|
func tunModuleIsLoaded() bool {
|
||||||
|
_, err := os.Stat("/dev/net/tun")
|
||||||
|
if err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("couldn't access device /dev/net/tun, go error %v, "+
|
||||||
|
"will attempt to load tun module, if running on container add flag --cap-add=NET_ADMIN", err)
|
||||||
|
|
||||||
|
tunLoaded, err := tryToLoadModule("tun")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to find or load tun module, got error: %v", err)
|
||||||
|
}
|
||||||
|
return tunLoaded
|
||||||
|
}
|
||||||
|
|
||||||
|
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||||
|
func WireguardModuleIsLoaded() bool {
|
||||||
|
if canCreateFakeWireguardInterface() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := tryToLoadModule("wireguard")
|
||||||
|
if err != nil {
|
||||||
|
log.Info(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
func canCreateFakeWireguardInterface() bool {
|
||||||
|
link := newWGLink("mustnotexist")
|
||||||
|
|
||||||
|
// We willingly try to create a device with an invalid
|
||||||
|
// MTU here as the validation of the MTU will be performed after
|
||||||
|
// the validation of the link kind and hence allows us to check
|
||||||
|
// for the existance of the wireguard module without actually
|
||||||
|
// creating a link.
|
||||||
|
//
|
||||||
|
// As a side-effect, this will also let the kernel lazy-load
|
||||||
|
// the wireguard module.
|
||||||
|
link.attrs.MTU = math.MaxInt
|
||||||
|
|
||||||
|
err := netlink.LinkAdd(link)
|
||||||
|
|
||||||
|
return errors.Is(err, syscall.EINVAL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func tryToLoadModule(moduleName string) (bool, error) {
|
||||||
|
if isModuleEnabled(moduleName) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
modulePath, err := getModulePath(moduleName)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("couldn't find module path for %s, error: %v", moduleName, err)
|
||||||
|
}
|
||||||
|
if modulePath == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("trying to load %s module", moduleName)
|
||||||
|
|
||||||
|
err = loadModuleWithDependencies(moduleName, modulePath)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("couldn't load %s module, error: %v", moduleName, err)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isModuleEnabled(name string) bool {
|
||||||
|
builtin, builtinErr := isBuiltinModule(name)
|
||||||
|
state, statusErr := moduleStatus(name)
|
||||||
|
return (builtinErr == nil && builtin) || (statusErr == nil && state >= loading)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getModulePath(name string) (string, error) {
|
||||||
|
var foundPath string
|
||||||
|
skipRemainingDirs := false
|
||||||
|
|
||||||
|
err := filepath.WalkDir(
|
||||||
|
moduleRoot,
|
||||||
|
func(path string, info fs.DirEntry, err error) error {
|
||||||
|
if skipRemainingDirs {
|
||||||
|
return fs.SkipDir
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// skip broken files
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !info.Type().IsRegular() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nameFromPath := pathToName(path)
|
||||||
|
if nameFromPath == name {
|
||||||
|
foundPath = path
|
||||||
|
skipRemainingDirs = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return foundPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pathToName(s string) string {
|
||||||
|
s = filepath.Base(s)
|
||||||
|
for ext := filepath.Ext(s); ext != ""; ext = filepath.Ext(s) {
|
||||||
|
s = strings.TrimSuffix(s, ext)
|
||||||
|
}
|
||||||
|
return cleanName(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanName(s string) string {
|
||||||
|
return strings.ReplaceAll(strings.TrimSpace(s), "-", "_")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBuiltinModule(name string) (bool, error) {
|
||||||
|
f, err := os.Open(filepath.Join(moduleRoot, "/modules.builtin"))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed closing modules.builtin file, %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var found bool
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if pathToName(line) == name {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return found, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// /proc/modules
|
||||||
|
// name | memory size | reference count | references | state: <Live|Loading|Unloading>
|
||||||
|
// macvlan 28672 1 macvtap, Live 0x0000000000000000
|
||||||
|
func moduleStatus(name string) (status, error) {
|
||||||
|
state := unknown
|
||||||
|
f, err := os.Open("/proc/modules")
|
||||||
|
if err != nil {
|
||||||
|
return state, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed closing /proc/modules file, %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
state = unloaded
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
fields := strings.Fields(scanner.Text())
|
||||||
|
if fields[0] == name {
|
||||||
|
if fields[2] != "0" {
|
||||||
|
state = inuse
|
||||||
|
break
|
||||||
|
}
|
||||||
|
switch fields[4] {
|
||||||
|
case "Live":
|
||||||
|
state = live
|
||||||
|
case "Loading":
|
||||||
|
state = loading
|
||||||
|
case "Unloading":
|
||||||
|
state = unloading
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return state, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadModuleWithDependencies(name, path string) error {
|
||||||
|
deps, err := getModuleDependencies(name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't load list of module %s dependecies", name)
|
||||||
|
}
|
||||||
|
for _, dep := range deps {
|
||||||
|
err = loadModule(dep.name, dep.path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't load dependecy module %s for %s", dep.name, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return loadModule(name, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadModule(name, path string) error {
|
||||||
|
state, err := moduleStatus(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if state >= loading {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed closing %s file, %v", path, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// first try finit_module(2), then init_module(2)
|
||||||
|
err = unix.FinitModule(int(f.Fd()), "", 0)
|
||||||
|
if errors.Is(err, unix.ENOSYS) {
|
||||||
|
buf, err := ioutil.ReadAll(f)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return unix.InitModule(buf, "")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// getModuleDependencies returns a module dependencies
|
||||||
|
func getModuleDependencies(name string) ([]module, error) {
|
||||||
|
f, err := os.Open(filepath.Join(moduleRoot, "/modules.dep"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed closing modules.dep file, %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var deps []string
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if pathToName(strings.TrimSuffix(fields[0], ":")) == name {
|
||||||
|
deps = fields
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(deps) == 0 {
|
||||||
|
return nil, ErrModuleNotFound
|
||||||
|
}
|
||||||
|
deps[0] = strings.TrimSuffix(deps[0], ":")
|
||||||
|
|
||||||
|
var modules []module
|
||||||
|
for _, v := range deps {
|
||||||
|
if pathToName(v) != name {
|
||||||
|
modules = append(modules, module{
|
||||||
|
name: pathToName(v),
|
||||||
|
path: filepath.Join(moduleRoot, v),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return modules, nil
|
||||||
|
}
|
221
iface/module_linux_test.go
Normal file
221
iface/module_linux_test.go
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetModuleDependencies(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
module string
|
||||||
|
expected []module
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get Single Dependency",
|
||||||
|
module: "bar",
|
||||||
|
expected: []module{
|
||||||
|
{name: "foo", path: "kernel/a/foo.ko"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get Multiple Dependencies",
|
||||||
|
module: "baz",
|
||||||
|
expected: []module{
|
||||||
|
{name: "foo", path: "kernel/a/foo.ko"},
|
||||||
|
{name: "bar", path: "kernel/a/bar.ko"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get No Dependencies",
|
||||||
|
module: "foo",
|
||||||
|
expected: []module{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
defer resetGlobals()
|
||||||
|
_, _ = createFiles(t)
|
||||||
|
modules, err := getModuleDependencies(testCase.module)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expected := testCase.expected
|
||||||
|
for i := range expected {
|
||||||
|
expected[i].path = moduleRoot + "/" + expected[i].path
|
||||||
|
}
|
||||||
|
|
||||||
|
require.ElementsMatchf(t, modules, expected, "returned modules should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsBuiltinModule(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
module string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Built In Should Return True",
|
||||||
|
module: "foo_bi",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not Built In Should Return False",
|
||||||
|
module: "not_built_in",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
defer resetGlobals()
|
||||||
|
_, _ = createFiles(t)
|
||||||
|
|
||||||
|
isBuiltIn, err := isBuiltinModule(testCase.module)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, testCase.expected, isBuiltIn)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModuleStatus(t *testing.T) {
|
||||||
|
random, err := getRandomLoadedModule(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("should be able to get random module")
|
||||||
|
}
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
module string
|
||||||
|
shouldBeLoaded bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Should Return Module Loading Or Greater Status",
|
||||||
|
module: random,
|
||||||
|
shouldBeLoaded: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Return Module Unloaded Or Lower Status",
|
||||||
|
module: "not_loaded_module",
|
||||||
|
shouldBeLoaded: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
defer resetGlobals()
|
||||||
|
_, _ = createFiles(t)
|
||||||
|
|
||||||
|
state, err := moduleStatus(testCase.module)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if testCase.shouldBeLoaded {
|
||||||
|
require.GreaterOrEqual(t, loading, state, "moduleStatus for %s should return state loading", testCase.module)
|
||||||
|
} else {
|
||||||
|
require.Less(t, state, loading, "module should return state unloading or lower")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetGlobals() {
|
||||||
|
moduleLibDir = defaultModuleDir
|
||||||
|
moduleRoot = getModuleRoot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createFiles(t *testing.T) (string, []module) {
|
||||||
|
writeFile := func(path, text string) {
|
||||||
|
if err := ioutil.WriteFile(path, []byte(text), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var u unix.Utsname
|
||||||
|
if err := unix.Uname(&u); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
moduleLibDir = t.TempDir()
|
||||||
|
|
||||||
|
moduleRoot = getModuleRoot()
|
||||||
|
if err := os.Mkdir(moduleRoot, 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
text := "kernel/a/foo.ko:\n"
|
||||||
|
text += "kernel/a/bar.ko: kernel/a/foo.ko\n"
|
||||||
|
text += "kernel/a/baz.ko: kernel/a/bar.ko kernel/a/foo.ko\n"
|
||||||
|
writeFile(filepath.Join(moduleRoot, "/modules.dep"), text)
|
||||||
|
|
||||||
|
text = "kernel/a/foo_bi.ko\n"
|
||||||
|
text += "kernel/a/bar-bi.ko.gz\n"
|
||||||
|
writeFile(filepath.Join(moduleRoot, "/modules.builtin"), text)
|
||||||
|
|
||||||
|
modules := []module{
|
||||||
|
{name: "foo", path: "kernel/a/foo.ko"},
|
||||||
|
{name: "bar", path: "kernel/a/bar.ko"},
|
||||||
|
{name: "baz", path: "kernel/a/baz.ko"},
|
||||||
|
}
|
||||||
|
return moduleLibDir, modules
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRandomLoadedModule(t *testing.T) (string, error) {
|
||||||
|
f, err := os.Open("/proc/modules")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("failed closing /proc/modules file, %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
lines, err := lineCounter(f)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
counter := 1
|
||||||
|
midLine := lines / 2
|
||||||
|
modName := ""
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
fields := strings.Fields(scanner.Text())
|
||||||
|
if counter == midLine {
|
||||||
|
if fields[4] == "Unloading" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modName = fields[0]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
counter++
|
||||||
|
}
|
||||||
|
if scanner.Err() != nil {
|
||||||
|
return "", scanner.Err()
|
||||||
|
}
|
||||||
|
return modName, nil
|
||||||
|
}
|
||||||
|
func lineCounter(r io.Reader) (int, error) {
|
||||||
|
buf := make([]byte, 32*1024)
|
||||||
|
count := 0
|
||||||
|
lineSep := []byte{'\n'}
|
||||||
|
|
||||||
|
for {
|
||||||
|
c, err := r.Read(buf)
|
||||||
|
count += bytes.Count(buf[:c], lineSep)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err == io.EOF:
|
||||||
|
return count, nil
|
||||||
|
|
||||||
|
case err != nil:
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user