refactor add filter acl

This commit is contained in:
Maycon Santos 2023-11-03 15:25:18 +01:00
parent e2f27502e4
commit b6af524187
5 changed files with 109 additions and 112 deletions

View File

@ -47,16 +47,7 @@ type Manager interface {
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
AddFiltering( AddFiltering(ruleRequest RuleRequest) ([]Rule, error)
ip net.IP,
proto Protocol,
sPort *Port,
dPort *Port,
direction RuleDirection,
action Action,
ipsetName string,
comment string,
) (Rule, error)
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
DeleteRule(rule Rule) error DeleteRule(rule Rule) error
@ -69,3 +60,23 @@ type Manager interface {
// TODO: migrate routemanager firewal actions to this interface // TODO: migrate routemanager firewal actions to this interface
} }
// RuleRequest is the request to create a rule
type RuleRequest struct {
// IP is the IP address of the rule
IP net.IP
// Proto is the protocol of the rule
Proto Protocol
// SrcPort is the source port of the rule
SrcPort *Port
// DstPort is the destination port of the rule
DstPort *Port
// Direction is the direction of the rule
Direction RuleDirection
// Action is the action of the rule
Action Action
// IPSetName is the name of the IPSet
IPSetName string
// Comment is the comment of the rule
Comment string
}

View File

@ -94,33 +94,24 @@ func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddFiltering( func (m *Manager) AddFiltering(request fw.RuleRequest) ([]fw.Rule, error) {
ip net.IP,
protocol fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
client, err := m.client(ip) client, err := m.client(request.IP)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var dPortVal, sPortVal string var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil { if request.DstPort != nil && request.DstPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs // TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0]) dPortVal = strconv.Itoa(request.DstPort.Values[0])
} }
if sPort != nil && sPort.Values != nil { if request.SrcPort != nil && request.SrcPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0]) sPortVal = strconv.Itoa(request.SrcPort.Values[0])
} }
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal) ipsetName := m.transformIPsetName(request.IPSetName, sPortVal, dPortVal)
ruleID := uuid.New().String() ruleID := uuid.New().String()
@ -135,28 +126,28 @@ func (m *Manager) AddFiltering(
} }
} }
if err := ipset.Add(ipsetName, ip.String()); err != nil { if err := ipset.Add(ipsetName, request.IP.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err) return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
} }
if rsExists { if rsExists {
// if ruleset already exists it means we already have the firewall rule // if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager. // so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
rs.ips[ip.String()] = ruleID rs.ips[request.IP.String()] = ruleID
return &Rule{ return []fw.Rule{&Rule{
ruleID: ruleID, ruleID: ruleID,
ipsetName: ipsetName, ipsetName: ipsetName,
ip: ip.String(), ip: request.IP.String(),
dst: direction == fw.RuleDirectionOUT, dst: request.Direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil, v6: request.IP.To4() == nil,
}, nil }}, nil
} }
// this is new ipset so we need to create firewall rule for it // this is new ipset so we need to create firewall rule for it
} }
specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName) specs := m.filterRuleSpecs(request.IP, string(request.Proto), sPortVal, dPortVal, request.Direction, request.Action, ipsetName)
if direction == fw.RuleDirectionOUT { if request.Direction == fw.RuleDirectionOUT {
ok, err := client.Exists("filter", ChainOutputFilterName, specs...) ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
if err != nil { if err != nil {
return nil, fmt.Errorf("check is output rule already exists: %w", err) return nil, fmt.Errorf("check is output rule already exists: %w", err)
@ -186,9 +177,9 @@ func (m *Manager) AddFiltering(
ruleID: ruleID, ruleID: ruleID,
specs: specs, specs: specs,
ipsetName: ipsetName, ipsetName: ipsetName,
ip: ip.String(), ip: request.IP.String(),
dst: direction == fw.RuleDirectionOUT, dst: request.Direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil, v6: request.IP.To4() == nil,
} }
if ipsetName != "" { if ipsetName != "" {
// ipset name is defined and it means that this rule was created // ipset name is defined and it means that this rule was created
@ -199,7 +190,7 @@ func (m *Manager) AddFiltering(
} }
} }
return rule, nil return []fw.Rule{rule}, nil
} }
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
@ -272,27 +263,31 @@ func (m *Manager) Reset() error {
func (m *Manager) AllowNetbird() error { func (m *Manager) AllowNetbird() error {
if m.wgIface.IsUserspaceBind() { if m.wgIface.IsUserspaceBind() {
_, err := m.AddFiltering( _, err := m.AddFiltering(
net.ParseIP("0.0.0.0"), fw.RuleRequest{
"all", IP: net.ParseIP("0.0.0.0"),
nil, Proto: "all",
nil, SrcPort: nil,
fw.RuleDirectionIN, DstPort: nil,
fw.ActionAccept, Direction: fw.RuleDirectionIN,
"", Action: fw.ActionAccept,
"", IPSetName: "",
Comment: "",
},
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err) return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
} }
_, err = m.AddFiltering( _, err = m.AddFiltering(
net.ParseIP("0.0.0.0"), fw.RuleRequest{
"all", IP: net.ParseIP("0.0.0.0"),
nil, Proto: "all",
nil, SrcPort: nil,
fw.RuleDirectionOUT, DstPort: nil,
fw.ActionAccept, Direction: fw.RuleDirectionOUT,
"", Action: fw.ActionAccept,
"", IPSetName: "",
Comment: "",
},
) )
return err return err
} }

View File

@ -96,16 +96,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddFiltering( func (m *Manager) AddFiltering(request fw.RuleRequest) ([]fw.Rule, error) {
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -116,16 +107,16 @@ func (m *Manager) AddFiltering(
chain *nftables.Chain chain *nftables.Chain
) )
if direction == fw.RuleDirectionOUT { if request.Direction == fw.RuleDirectionOUT {
table, chain, err = m.chain( table, chain, err = m.chain(
ip, request.IP,
FilterOutputChainName, FilterOutputChainName,
nftables.ChainHookOutput, nftables.ChainHookOutput,
nftables.ChainPriorityFilter, nftables.ChainPriorityFilter,
nftables.ChainTypeFilter) nftables.ChainTypeFilter)
} else { } else {
table, chain, err = m.chain( table, chain, err = m.chain(
ip, request.IP,
FilterInputChainName, FilterInputChainName,
nftables.ChainHookInput, nftables.ChainHookInput,
nftables.ChainPriorityFilter, nftables.ChainPriorityFilter,
@ -135,22 +126,22 @@ func (m *Manager) AddFiltering(
return nil, err return nil, err
} }
rawIP := ip.To4() rawIP := request.IP.To4()
if rawIP == nil { if rawIP == nil {
rawIP = ip.To16() rawIP = request.IP.To16()
} }
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName) rulesetID := m.getRulesetID(request.IP, request.Proto, request.SrcPort, request.DstPort, request.Direction, request.Action, request.IPSetName)
if ipsetName != "" { if request.IPSetName != "" {
// if we already have set with given name, just add ip to the set // if we already have set with given name, just add ip to the set
// and return rule with new ID in other case let's create rule // and return rule with new ID in other case let's create rule
// with fresh created set and set element // with fresh created set and set element
var isSetNew bool var isSetNew bool
ipset, err = m.rConn.GetSetByName(table, ipsetName) ipset, err = m.rConn.GetSetByName(table, request.IPSetName)
if err != nil { if err != nil {
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil { if ipset, err = m.createSet(table, rawIP, request.IPSetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err) return nil, fmt.Errorf("get set name: %v", err)
} }
isSetNew = true isSetNew = true
@ -168,7 +159,11 @@ func (m *Manager) AddFiltering(
// just add new rule to the ruleset and return new fw.Rule object // just add new rule to the ruleset and return new fw.Rule object
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok { if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
return m.rulesetManager.addRule(ruleset, rawIP) r, err := m.rulesetManager.addRule(ruleset, rawIP)
if err != nil {
return nil, fmt.Errorf("add rule: %v", err)
}
return []fw.Rule{r}, nil
} }
// if ipset exists but it is not linked to rule for given direction // if ipset exists but it is not linked to rule for given direction
// create new rule for direction and bind ipset to it later // create new rule for direction and bind ipset to it later
@ -176,7 +171,7 @@ func (m *Manager) AddFiltering(
} }
ifaceKey := expr.MetaKeyIIFNAME ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT { if request.Direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME ifaceKey = expr.MetaKeyOIFNAME
} }
expressions := []expr.Any{ expressions := []expr.Any{
@ -188,7 +183,7 @@ func (m *Manager) AddFiltering(
}, },
} }
if proto != "all" { if request.Proto != "all" {
expressions = append(expressions, &expr.Payload{ expressions = append(expressions, &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
@ -197,7 +192,7 @@ func (m *Manager) AddFiltering(
}) })
var protoData []byte var protoData []byte
switch proto { switch request.Proto {
case fw.ProtocolTCP: case fw.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP} protoData = []byte{unix.IPPROTO_TCP}
case fw.ProtocolUDP: case fw.ProtocolUDP:
@ -205,7 +200,7 @@ func (m *Manager) AddFiltering(
case fw.ProtocolICMP: case fw.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP} protoData = []byte{unix.IPPROTO_ICMP}
default: default:
return nil, fmt.Errorf("unsupported protocol: %s", proto) return nil, fmt.Errorf("unsupported protocol: %s", request.Proto)
} }
expressions = append(expressions, &expr.Cmp{ expressions = append(expressions, &expr.Cmp{
Register: 1, Register: 1,
@ -225,7 +220,7 @@ func (m *Manager) AddFiltering(
} }
// change to destination address position if need // change to destination address position if need
if direction == fw.RuleDirectionOUT { if request.Direction == fw.RuleDirectionOUT {
addrOffset += addrLen addrOffset += addrLen
} }
@ -250,14 +245,14 @@ func (m *Manager) AddFiltering(
expressions = append(expressions, expressions = append(expressions,
&expr.Lookup{ &expr.Lookup{
SourceRegister: 1, SourceRegister: 1,
SetName: ipsetName, SetName: request.IPSetName,
SetID: ipset.ID, SetID: ipset.ID,
}, },
) )
} }
} }
if sPort != nil && len(sPort.Values) != 0 { if request.SrcPort != nil && len(request.SrcPort.Values) != 0 {
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@ -268,12 +263,12 @@ func (m *Manager) AddFiltering(
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: encodePort(*sPort), Data: encodePort(*request.SrcPort),
}, },
) )
} }
if dPort != nil && len(dPort.Values) != 0 { if request.DstPort != nil && len(request.DstPort.Values) != 0 {
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@ -284,18 +279,18 @@ func (m *Manager) AddFiltering(
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: encodePort(*dPort), Data: encodePort(*request.DstPort),
}, },
) )
} }
if action == fw.ActionAccept { if request.Action == fw.ActionAccept {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept}) expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
} else { } else {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
} }
userData := []byte(strings.Join([]string{rulesetID, comment}, " ")) userData := []byte(strings.Join([]string{rulesetID, request.Comment}, " "))
rule := m.rConn.InsertRule(&nftables.Rule{ rule := m.rConn.InsertRule(&nftables.Rule{
Table: table, Table: table,
@ -309,7 +304,11 @@ func (m *Manager) AddFiltering(
} }
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset) ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
return m.rulesetManager.addRule(ruleset, rawIP) r, err := m.rulesetManager.addRule(ruleset, rawIP)
if err != nil {
return nil, fmt.Errorf("add rule: %v", err)
}
return []fw.Rule{r}, nil
} }
// getRulesetID returns ruleset ID based on given parameters // getRulesetID returns ruleset ID based on given parameters

View File

@ -81,26 +81,17 @@ func Create(iface IFaceMapper) (*Manager, error) {
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddFiltering( func (m *Manager) AddFiltering(request fw.RuleRequest) ([]fw.Rule, error) {
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
r := Rule{ r := Rule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: request.IP,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
matchByIP: true, matchByIP: true,
direction: direction, direction: request.Direction,
drop: action == fw.ActionDrop, drop: request.Action == fw.ActionDrop,
comment: comment, comment: request.Comment,
} }
if ipNormalized := ip.To4(); ipNormalized != nil { if ipNormalized := request.IP.To4(); ipNormalized != nil {
r.ipLayer = layers.LayerTypeIPv4 r.ipLayer = layers.LayerTypeIPv4
r.ip = ipNormalized r.ip = ipNormalized
} }
@ -109,15 +100,15 @@ func (m *Manager) AddFiltering(
r.matchByIP = false r.matchByIP = false
} }
if sPort != nil && len(sPort.Values) == 1 { if request.SrcPort != nil && len(request.SrcPort.Values) == 1 {
r.sPort = uint16(sPort.Values[0]) r.sPort = uint16(request.SrcPort.Values[0])
} }
if dPort != nil && len(dPort.Values) == 1 { if request.DstPort != nil && len(request.DstPort.Values) == 1 {
r.dPort = uint16(dPort.Values[0]) r.dPort = uint16(request.DstPort.Values[0])
} }
switch proto { switch request.Proto {
case fw.ProtocolTCP: case fw.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP r.protoLayer = layers.LayerTypeTCP
case fw.ProtocolUDP: case fw.ProtocolUDP:
@ -132,7 +123,7 @@ func (m *Manager) AddFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
if direction == fw.RuleDirectionIN { if request.Direction == fw.RuleDirectionIN {
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet) m.incomingRules[r.ip.String()] = make(RuleSet)
} }
@ -145,7 +136,7 @@ func (m *Manager) AddFiltering(
} }
m.mutex.Unlock() m.mutex.Unlock()
return &r, nil return []fw.Rule{&r}, nil
} }
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition

View File

@ -287,13 +287,14 @@ func (d *DefaultManager) addOutRules(
ipsetName string, ipsetName string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.manager.AddFiltering( rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
rules = append(rules, rule) rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) { if shouldSkipInvertedRule(protocol, port) {
return rules, nil return rules, nil
@ -305,7 +306,7 @@ func (d *DefaultManager) addOutRules(
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
return append(rules, rule), nil return append(rules, rule...), nil
} }
// getRuleID() returns unique ID for the rule based on its parameters. // getRuleID() returns unique ID for the rule based on its parameters.