//go:build linux && !android package server import ( "archive/zip" "bytes" "encoding/binary" "fmt" "os/exec" "sort" "strings" "github.com/google/nftables" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/proto" ) // addFirewallRules collects and adds firewall rules to the archive func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { log.Info("Collecting firewall rules") // Collect and add iptables rules iptablesRules, err := collectIPTablesRules() if err != nil { log.Warnf("Failed to collect iptables rules: %v", err) } else { if req.GetAnonymize() { iptablesRules = anonymizer.AnonymizeString(iptablesRules) } if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil { log.Warnf("Failed to add iptables rules to bundle: %v", err) } } // Collect and add nftables rules nftablesRules, err := collectNFTablesRules() if err != nil { log.Warnf("Failed to collect nftables rules: %v", err) } else { if req.GetAnonymize() { nftablesRules = anonymizer.AnonymizeString(nftablesRules) } if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil { log.Warnf("Failed to add nftables rules to bundle: %v", err) } } return nil } // collectIPTablesRules collects rules using both iptables-save and verbose listing func collectIPTablesRules() (string, error) { var builder strings.Builder // First try using iptables-save saveOutput, err := collectIPTablesSave() if err != nil { log.Warnf("Failed to collect iptables rules using iptables-save: %v", err) } else { builder.WriteString("=== iptables-save output ===\n") builder.WriteString(saveOutput) builder.WriteString("\n") } // Then get verbose statistics for each table builder.WriteString("=== iptables -v -n -L output ===\n") // Get list of tables tables := []string{"filter", "nat", "mangle", "raw", "security"} for _, table := range tables { builder.WriteString(fmt.Sprintf("*%s\n", table)) // Get verbose statistics for the entire table stats, err := getTableStatistics(table) if err != nil { log.Warnf("Failed to get statistics for table %s: %v", table, err) continue } builder.WriteString(stats) builder.WriteString("\n") } return builder.String(), nil } // collectIPTablesSave uses iptables-save to get rule definitions func collectIPTablesSave() (string, error) { cmd := exec.Command("iptables-save") var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr if err := cmd.Run(); err != nil { return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String()) } rules := stdout.String() if strings.TrimSpace(rules) == "" { return "", fmt.Errorf("no iptables rules found") } return rules, nil } // getTableStatistics gets verbose statistics for an entire table using iptables command func getTableStatistics(table string) (string, error) { cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table) var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr if err := cmd.Run(); err != nil { return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String()) } return stdout.String(), nil } // collectNFTablesRules attempts to collect nftables rules using either nft command or netlink func collectNFTablesRules() (string, error) { // First try using nft command rules, err := collectNFTablesFromCommand() if err != nil { log.Debugf("Failed to collect nftables rules using nft command: %v, falling back to netlink", err) // Fall back to netlink rules, err = collectNFTablesFromNetlink() if err != nil { return "", fmt.Errorf("collect nftables rules using both nft and netlink failed: %w", err) } } return rules, nil } // collectNFTablesFromCommand attempts to collect rules using nft command func collectNFTablesFromCommand() (string, error) { cmd := exec.Command("nft", "-a", "list", "ruleset") var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr if err := cmd.Run(); err != nil { return "", fmt.Errorf("execute nft list ruleset: %w (stderr: %s)", err, stderr.String()) } rules := stdout.String() if strings.TrimSpace(rules) == "" { return "", fmt.Errorf("no nftables rules found") } return rules, nil } // collectNFTablesFromNetlink collects rules using netlink library func collectNFTablesFromNetlink() (string, error) { conn, err := nftables.New() if err != nil { return "", fmt.Errorf("create nftables connection: %w", err) } tables, err := conn.ListTables() if err != nil { return "", fmt.Errorf("list tables: %w", err) } sortTables(tables) return formatTables(conn, tables), nil } func formatTables(conn *nftables.Conn, tables []*nftables.Table) string { var builder strings.Builder for _, table := range tables { builder.WriteString(fmt.Sprintf("table %s %s {\n", formatFamily(table.Family), table.Name)) chains, err := getAndSortTableChains(conn, table) if err != nil { log.Warnf("Failed to list chains for table %s: %v", table.Name, err) continue } // Format chains for _, chain := range chains { formatChain(conn, table, chain, &builder) } // Format sets if sets, err := conn.GetSets(table); err != nil { log.Warnf("Failed to get sets for table %s: %v", table.Name, err) } else if len(sets) > 0 { builder.WriteString("\n") for _, set := range sets { builder.WriteString(formatSet(conn, set)) } } builder.WriteString("}\n") } return builder.String() } func getAndSortTableChains(conn *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) { chains, err := conn.ListChains() if err != nil { return nil, err } var tableChains []*nftables.Chain for _, chain := range chains { if chain.Table.Name == table.Name && chain.Table.Family == table.Family { tableChains = append(tableChains, chain) } } sort.Slice(tableChains, func(i, j int) bool { return tableChains[i].Name < tableChains[j].Name }) return tableChains, nil } func formatChain(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, builder *strings.Builder) { builder.WriteString(fmt.Sprintf("\tchain %s {\n", chain.Name)) if chain.Type != "" { var policy string if chain.Policy != nil { policy = fmt.Sprintf("; policy %s", formatPolicy(*chain.Policy)) } builder.WriteString(fmt.Sprintf("\t\ttype %s hook %s priority %d%s\n", formatChainType(chain.Type), formatChainHook(chain.Hooknum), chain.Priority, policy)) } rules, err := conn.GetRules(table, chain) if err != nil { log.Warnf("Failed to get rules for chain %s: %v", chain.Name, err) } else { sort.Slice(rules, func(i, j int) bool { return rules[i].Position < rules[j].Position }) for _, rule := range rules { builder.WriteString(formatRule(rule)) } } builder.WriteString("\t}\n") } func sortTables(tables []*nftables.Table) { sort.Slice(tables, func(i, j int) bool { if tables[i].Family != tables[j].Family { return tables[i].Family < tables[j].Family } return tables[i].Name < tables[j].Name }) } func formatFamily(family nftables.TableFamily) string { switch family { case nftables.TableFamilyIPv4: return "ip" case nftables.TableFamilyIPv6: return "ip6" case nftables.TableFamilyINet: return "inet" case nftables.TableFamilyARP: return "arp" case nftables.TableFamilyBridge: return "bridge" case nftables.TableFamilyNetdev: return "netdev" default: return fmt.Sprintf("family-%d", family) } } func formatChainType(typ nftables.ChainType) string { switch typ { case nftables.ChainTypeFilter: return "filter" case nftables.ChainTypeNAT: return "nat" case nftables.ChainTypeRoute: return "route" default: return fmt.Sprintf("type-%s", typ) } } func formatChainHook(hook *nftables.ChainHook) string { if hook == nil { return "none" } switch *hook { case *nftables.ChainHookPrerouting: return "prerouting" case *nftables.ChainHookInput: return "input" case *nftables.ChainHookForward: return "forward" case *nftables.ChainHookOutput: return "output" case *nftables.ChainHookPostrouting: return "postrouting" default: return fmt.Sprintf("hook-%d", *hook) } } func formatPolicy(policy nftables.ChainPolicy) string { switch policy { case nftables.ChainPolicyDrop: return "drop" case nftables.ChainPolicyAccept: return "accept" default: return fmt.Sprintf("policy-%d", policy) } } func formatRule(rule *nftables.Rule) string { var builder strings.Builder builder.WriteString("\t\t") for i := 0; i < len(rule.Exprs); i++ { if i > 0 { builder.WriteString(" ") } i = formatExprSequence(&builder, rule.Exprs, i) } builder.WriteString("\n") return builder.String() } func formatExprSequence(builder *strings.Builder, exprs []expr.Any, i int) int { curr := exprs[i] // Handle Meta + Cmp sequence if meta, ok := curr.(*expr.Meta); ok && i+1 < len(exprs) { if cmp, ok := exprs[i+1].(*expr.Cmp); ok { if formatted := formatMetaWithCmp(meta, cmp); formatted != "" { builder.WriteString(formatted) return i + 1 } } } // Handle Payload + Cmp sequence if payload, ok := curr.(*expr.Payload); ok && i+1 < len(exprs) { if cmp, ok := exprs[i+1].(*expr.Cmp); ok { builder.WriteString(formatPayloadWithCmp(payload, cmp)) return i + 1 } } builder.WriteString(formatExpr(curr)) return i } func formatMetaWithCmp(meta *expr.Meta, cmp *expr.Cmp) string { switch meta.Key { case expr.MetaKeyIIFNAME: name := strings.TrimRight(string(cmp.Data), "\x00") return fmt.Sprintf("iifname %s %q", formatCmpOp(cmp.Op), name) case expr.MetaKeyOIFNAME: name := strings.TrimRight(string(cmp.Data), "\x00") return fmt.Sprintf("oifname %s %q", formatCmpOp(cmp.Op), name) case expr.MetaKeyMARK: if len(cmp.Data) == 4 { val := binary.BigEndian.Uint32(cmp.Data) return fmt.Sprintf("meta mark %s 0x%x", formatCmpOp(cmp.Op), val) } } return "" } func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string { if p.Base == expr.PayloadBaseNetworkHeader { switch p.Offset { case 12: // Source IP if p.Len == 4 { return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) } else if p.Len == 2 { return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) } case 16: // Destination IP if p.Len == 4 { return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) } else if p.Len == 2 { return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) } } } return fmt.Sprintf("%d reg%d [%d:%d] %s %v", p.Base, p.DestRegister, p.Offset, p.Len, formatCmpOp(cmp.Op), cmp.Data) } func formatIPBytes(data []byte) string { if len(data) == 4 { return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3]) } else if len(data) == 2 { return fmt.Sprintf("%d.%d.0.0/16", data[0], data[1]) } return fmt.Sprintf("%v", data) } func formatCmpOp(op expr.CmpOp) string { switch op { case expr.CmpOpEq: return "==" case expr.CmpOpNeq: return "!=" case expr.CmpOpLt: return "<" case expr.CmpOpLte: return "<=" case expr.CmpOpGt: return ">" case expr.CmpOpGte: return ">=" default: return fmt.Sprintf("op-%d", op) } } // formatExpr formats an expression in nft-like syntax func formatExpr(exp expr.Any) string { switch e := exp.(type) { case *expr.Meta: return formatMeta(e) case *expr.Cmp: return formatCmp(e) case *expr.Payload: return formatPayload(e) case *expr.Verdict: return formatVerdict(e) case *expr.Counter: return fmt.Sprintf("counter packets %d bytes %d", e.Packets, e.Bytes) case *expr.Masq: return "masquerade" case *expr.NAT: return formatNat(e) case *expr.Match: return formatMatch(e) case *expr.Queue: return fmt.Sprintf("queue num %d", e.Num) case *expr.Lookup: return fmt.Sprintf("@%s", e.SetName) case *expr.Bitwise: return formatBitwise(e) case *expr.Fib: return formatFib(e) case *expr.Target: return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets case *expr.Immediate: if e.Register == 1 { return formatImmediateData(e.Data) } return fmt.Sprintf("immediate %v", e.Data) default: return fmt.Sprintf("<%T>", exp) } } func formatImmediateData(data []byte) string { // For IP addresses (4 bytes) if len(data) == 4 { return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3]) } return fmt.Sprintf("%v", data) } func formatMeta(e *expr.Meta) string { // Handle source register case first (meta mark set) if e.SourceRegister { return fmt.Sprintf("meta %s set reg %d", formatMetaKey(e.Key), e.Register) } // For interface names, handle register load operation switch e.Key { case expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME, expr.MetaKeyBRIIIFNAME, expr.MetaKeyBRIOIFNAME: // Simply the key name with no register reference return formatMetaKey(e.Key) case expr.MetaKeyMARK: // For mark operations, we want just "mark" return "mark" } // For other meta keys, show as loading into register return fmt.Sprintf("meta %s => reg %d", formatMetaKey(e.Key), e.Register) } func formatMetaKey(key expr.MetaKey) string { switch key { case expr.MetaKeyLEN: return "length" case expr.MetaKeyPROTOCOL: return "protocol" case expr.MetaKeyPRIORITY: return "priority" case expr.MetaKeyMARK: return "mark" case expr.MetaKeyIIF: return "iif" case expr.MetaKeyOIF: return "oif" case expr.MetaKeyIIFNAME: return "iifname" case expr.MetaKeyOIFNAME: return "oifname" case expr.MetaKeyIIFTYPE: return "iiftype" case expr.MetaKeyOIFTYPE: return "oiftype" case expr.MetaKeySKUID: return "skuid" case expr.MetaKeySKGID: return "skgid" case expr.MetaKeyNFTRACE: return "nftrace" case expr.MetaKeyRTCLASSID: return "rtclassid" case expr.MetaKeySECMARK: return "secmark" case expr.MetaKeyNFPROTO: return "nfproto" case expr.MetaKeyL4PROTO: return "l4proto" case expr.MetaKeyBRIIIFNAME: return "briifname" case expr.MetaKeyBRIOIFNAME: return "broifname" case expr.MetaKeyPKTTYPE: return "pkttype" case expr.MetaKeyCPU: return "cpu" case expr.MetaKeyIIFGROUP: return "iifgroup" case expr.MetaKeyOIFGROUP: return "oifgroup" case expr.MetaKeyCGROUP: return "cgroup" case expr.MetaKeyPRANDOM: return "prandom" default: return fmt.Sprintf("meta-%d", key) } } func formatCmp(e *expr.Cmp) string { ops := map[expr.CmpOp]string{ expr.CmpOpEq: "==", expr.CmpOpNeq: "!=", expr.CmpOpLt: "<", expr.CmpOpLte: "<=", expr.CmpOpGt: ">", expr.CmpOpGte: ">=", } return fmt.Sprintf("%s %v", ops[e.Op], e.Data) } func formatPayload(e *expr.Payload) string { var proto string switch e.Base { case expr.PayloadBaseNetworkHeader: proto = "ip" case expr.PayloadBaseTransportHeader: proto = "tcp" default: proto = fmt.Sprintf("payload-%d", e.Base) } return fmt.Sprintf("%s reg%d [%d:%d]", proto, e.DestRegister, e.Offset, e.Len) } func formatVerdict(e *expr.Verdict) string { switch e.Kind { case expr.VerdictAccept: return "accept" case expr.VerdictDrop: return "drop" case expr.VerdictJump: return fmt.Sprintf("jump %s", e.Chain) case expr.VerdictGoto: return fmt.Sprintf("goto %s", e.Chain) case expr.VerdictReturn: return "return" default: return fmt.Sprintf("verdict-%d", e.Kind) } } func formatNat(e *expr.NAT) string { switch e.Type { case expr.NATTypeSourceNAT: return "snat" case expr.NATTypeDestNAT: return "dnat" default: return fmt.Sprintf("nat-%d", e.Type) } } func formatMatch(e *expr.Match) string { return fmt.Sprintf("match %s rev %d", e.Name, e.Rev) } func formatBitwise(e *expr.Bitwise) string { return fmt.Sprintf("bitwise reg%d = reg%d & %v ^ %v", e.DestRegister, e.SourceRegister, e.Mask, e.Xor) } func formatFib(e *expr.Fib) string { var flags []string if e.FlagSADDR { flags = append(flags, "saddr") } if e.FlagDADDR { flags = append(flags, "daddr") } if e.FlagMARK { flags = append(flags, "mark") } if e.FlagIIF { flags = append(flags, "iif") } if e.FlagOIF { flags = append(flags, "oif") } if e.ResultADDRTYPE { flags = append(flags, "type") } return fmt.Sprintf("fib reg%d %s", e.Register, strings.Join(flags, ",")) } func formatSet(conn *nftables.Conn, set *nftables.Set) string { var builder strings.Builder builder.WriteString(fmt.Sprintf("\tset %s {\n", set.Name)) builder.WriteString(fmt.Sprintf("\t\ttype %s\n", formatSetKeyType(set.KeyType))) if set.ID > 0 { builder.WriteString(fmt.Sprintf("\t\t# handle %d\n", set.ID)) } elements, err := conn.GetSetElements(set) if err != nil { log.Warnf("Failed to get elements for set %s: %v", set.Name, err) } else if len(elements) > 0 { builder.WriteString("\t\telements = {") for i, elem := range elements { if i > 0 { builder.WriteString(", ") } builder.WriteString(fmt.Sprintf("%v", elem.Key)) } builder.WriteString("}\n") } builder.WriteString("\t}\n") return builder.String() } func formatSetKeyType(keyType nftables.SetDatatype) string { switch keyType { case nftables.TypeInvalid: return "invalid" case nftables.TypeIPAddr: return "ipv4_addr" case nftables.TypeIP6Addr: return "ipv6_addr" case nftables.TypeEtherAddr: return "ether_addr" case nftables.TypeInetProto: return "inet_proto" case nftables.TypeInetService: return "inet_service" case nftables.TypeMark: return "mark" default: return fmt.Sprintf("type-%v", keyType) } }