package iface

import (
	"bufio"
	"bytes"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
	"golang.org/x/sys/unix"
)

// TestGetModuleDependencies checks that getModuleDependencies returns the
// dependencies listed for a module in the modules.dep fixture written by
// createFiles.
func TestGetModuleDependencies(t *testing.T) {
	testCases := []struct {
		name     string
		module   string
		expected []module
	}{
		{
			name:   "Get Single Dependency",
			module: "bar",
			expected: []module{
				{name: "foo", path: "kernel/a/foo.ko"},
			},
		},
		{
			name:   "Get Multiple Dependencies",
			module: "baz",
			expected: []module{
				{name: "foo", path: "kernel/a/foo.ko"},
				{name: "bar", path: "kernel/a/bar.ko"},
			},
		},
		{
			name:     "Get No Dependencies",
			module:   "foo",
			expected: []module{},
		},
	}

	for _, testCase := range testCases {
		t.Run(testCase.name, func(t *testing.T) {
			defer resetGlobals()
			_, _ = createFiles(t)

			modules, err := getModuleDependencies(testCase.module)
			require.NoError(t, err)

			// Build a fresh slice instead of rewriting the table entry in
			// place, so the test table stays untouched.
			expected := make([]module, len(testCase.expected))
			for i, m := range testCase.expected {
				m.path = moduleRoot + "/" + m.path
				expected[i] = m
			}
			require.ElementsMatchf(t, modules, expected, "returned modules should match")
		})
	}
}

// TestIsBuiltinModule checks isBuiltinModule against the modules.builtin
// fixture written by createFiles, which lists foo_bi but not not_built_in.
func TestIsBuiltinModule(t *testing.T) {
	testCases := []struct {
		name     string
		module   string
		expected bool
	}{
		{
			name:     "Built In Should Return True",
			module:   "foo_bi",
			expected: true,
		},
		{
			name:     "Not Built In Should Return False",
			module:   "not_built_in",
			expected: false,
		},
	}

	for _, testCase := range testCases {
		t.Run(testCase.name, func(t *testing.T) {
			defer resetGlobals()
			_, _ = createFiles(t)

			isBuiltIn, err := isBuiltinModule(testCase.module)
			require.NoError(t, err)
			require.Equal(t, testCase.expected, isBuiltIn)
		})
	}
}

// TestModuleStatus checks that moduleStatus reports a state of at least
// loading for a module listed in /proc/modules, and a state below loading for
// a module name that is not expected to be loaded.
func TestModuleStatus(t *testing.T) {
	random, err := getRandomLoadedModule(t)
	if err != nil {
		t.Fatalf("should be able to get random module: %v", err)
	}
	if random == "" {
		// Nothing to test against, e.g. a kernel with no loaded modules.
		t.Skip("no loaded kernel module found in /proc/modules")
	}

	testCases := []struct {
		name           string
		module         string
		shouldBeLoaded bool
	}{
		{
			name:           "Should Return Module Loading Or Greater Status",
			module:         random,
			shouldBeLoaded: true,
		},
		{
			name:           "Should Return Module Unloaded Or Lower Status",
			module:         "not_loaded_module",
			shouldBeLoaded: false,
		},
	}

	for _, testCase := range testCases {
		t.Run(testCase.name, func(t *testing.T) {
			defer resetGlobals()
			_, _ = createFiles(t)

			state, err := moduleStatus(testCase.module)
			require.NoError(t, err)

			if testCase.shouldBeLoaded {
				// GreaterOrEqual asserts first >= second, so state goes first.
				require.GreaterOrEqual(t, state, loading, "moduleStatus for %s should return state loading", testCase.module)
			} else {
				require.Less(t, state, loading, "module should return state unloading or lower")
			}
		})
	}
}

// resetGlobals restores the package-level module directory settings that
// createFiles overrides.
func resetGlobals() {
	moduleLibDir = defaultModuleDir
	moduleRoot = getModuleRoot()
}

// createFiles points moduleLibDir at a fresh temporary directory, recomputes
// moduleRoot, and writes modules.dep and modules.builtin fixtures into it.
// It returns the temporary moduleLibDir and the modules described by the
// modules.dep fixture (with paths relative to moduleRoot). Any failure aborts
// the test via t.Fatal. Callers should restore the globals with resetGlobals.
func createFiles(t *testing.T) (string, []module) {
	writeFile := func(path, text string) {
		if err := ioutil.WriteFile(path, []byte(text), 0644); err != nil {
			t.Fatal(err)
		}
	}

	// NOTE(review): u is not used afterwards; presumably this only verifies
	// that uname works before getModuleRoot is called -- confirm.
	var u unix.Utsname
	if err := unix.Uname(&u); err != nil {
		t.Fatal(err)
	}

	moduleLibDir = t.TempDir()
	moduleRoot = getModuleRoot()
	if err := os.Mkdir(moduleRoot, 0755); err != nil {
		t.Fatal(err)
	}

	// modules.dep format: "<module path>: <dependency paths...>".
	text := "kernel/a/foo.ko:\n"
	text += "kernel/a/bar.ko: kernel/a/foo.ko\n"
	text += "kernel/a/baz.ko: kernel/a/bar.ko kernel/a/foo.ko\n"
	writeFile(filepath.Join(moduleRoot, "/modules.dep"), text)

	text = "kernel/a/foo_bi.ko\n"
	text += "kernel/a/bar-bi.ko.gz\n"
	writeFile(filepath.Join(moduleRoot, "/modules.builtin"), text)

	modules := []module{
		{name: "foo", path: "kernel/a/foo.ko"},
		{name: "bar", path: "kernel/a/bar.ko"},
		{name: "baz", path: "kernel/a/baz.ko"},
	}
	return moduleLibDir, modules
}

// getRandomLoadedModule returns the name of a module listed around the middle
// of /proc/modules, skipping entries whose state field is "Unloading". It
// returns an empty name and a nil error when no suitable entry is found, and
// a non-nil error when /proc/modules cannot be opened or read.
func getRandomLoadedModule(t *testing.T) (string, error) {
	f, err := os.Open("/proc/modules")
	if err != nil {
		return "", err
	}
	defer func() {
		if err := f.Close(); err != nil {
			t.Logf("failed closing /proc/modules file, %v", err)
		}
	}()

	lines, err := lineCounter(f)
	if err != nil {
		return "", err
	}

	// lineCounter consumed the file up to EOF; rewind so the scanner below
	// actually sees the content.
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		return "", err
	}

	// Line numbers are 1-based; clamp so a file with fewer than two lines
	// still yields its first entry.
	midLine := lines / 2
	if midLine < 1 {
		midLine = 1
	}

	counter := 1
	modName := ""
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		if counter < midLine {
			counter++
			continue
		}

		// From the middle line on, take the first entry that is not being
		// unloaded. The fifth field is compared against the state name.
		fields := strings.Fields(scanner.Text())
		if len(fields) == 0 {
			continue
		}
		if len(fields) > 4 && fields[4] == "Unloading" {
			continue
		}
		modName = fields[0]
		break
	}
	if err := scanner.Err(); err != nil {
		return "", err
	}
	return modName, nil
}

// lineCounter reads r until EOF and returns the number of newline bytes seen.
// A final line without a trailing newline is not counted. On a read error
// other than io.EOF it returns the count so far together with that error.
func lineCounter(r io.Reader) (int, error) {
	buf := make([]byte, 32*1024)
	count := 0
	lineSep := []byte{'\n'}

	for {
		c, err := r.Read(buf)
		// Count before inspecting err: Read may return data along with an error.
		count += bytes.Count(buf[:c], lineSep)

		switch {
		case err == io.EOF:
			return count, nil
		case err != nil:
			return count, err
		}
	}
}