refactor add filter acl

This commit is contained in:
Maycon Santos 2023-11-03 15:25:18 +01:00
parent e2f27502e4
commit b6af524187
5 changed files with 109 additions and 112 deletions

View File

@ -47,16 +47,7 @@ type Manager interface {
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
AddFiltering(
ip net.IP,
proto Protocol,
sPort *Port,
dPort *Port,
direction RuleDirection,
action Action,
ipsetName string,
comment string,
) (Rule, error)
AddFiltering(ruleRequest RuleRequest) ([]Rule, error)
// DeleteRule from the firewall by rule definition
DeleteRule(rule Rule) error
@ -69,3 +60,23 @@ type Manager interface {
// TODO: migrate routemanager firewal actions to this interface
}
// RuleRequest is the request to create a rule
type RuleRequest struct {
// IP is the IP address of the rule
IP net.IP
// Proto is the protocol of the rule
Proto Protocol
// SrcPort is the source port of the rule
SrcPort *Port
// DstPort is the destination port of the rule
DstPort *Port
// Direction is the direction of the rule
Direction RuleDirection
// Action is the action of the rule
Action Action
// IPSetName is the name of the IPSet
IPSetName string
// Comment is the comment of the rule
Comment string
}

View File

@ -94,33 +94,24 @@ func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
// AddFiltering rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddFiltering(
ip net.IP,
protocol fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
func (m *Manager) AddFiltering(request fw.RuleRequest) ([]fw.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
client, err := m.client(ip)
client, err := m.client(request.IP)
if err != nil {
return nil, err
}
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
if request.DstPort != nil && request.DstPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
dPortVal = strconv.Itoa(request.DstPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
if request.SrcPort != nil && request.SrcPort.Values != nil {
sPortVal = strconv.Itoa(request.SrcPort.Values[0])
}
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
ipsetName := m.transformIPsetName(request.IPSetName, sPortVal, dPortVal)
ruleID := uuid.New().String()
@ -135,28 +126,28 @@ func (m *Manager) AddFiltering(
}
}
if err := ipset.Add(ipsetName, ip.String()); err != nil {
if err := ipset.Add(ipsetName, request.IP.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
if rsExists {
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
rs.ips[ip.String()] = ruleID
return &Rule{
rs.ips[request.IP.String()] = ruleID
return []fw.Rule{&Rule{
ruleID: ruleID,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}, nil
ip: request.IP.String(),
dst: request.Direction == fw.RuleDirectionOUT,
v6: request.IP.To4() == nil,
}}, nil
}
// this is new ipset so we need to create firewall rule for it
}
specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
specs := m.filterRuleSpecs(request.IP, string(request.Proto), sPortVal, dPortVal, request.Direction, request.Action, ipsetName)
if direction == fw.RuleDirectionOUT {
if request.Direction == fw.RuleDirectionOUT {
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
if err != nil {
return nil, fmt.Errorf("check is output rule already exists: %w", err)
@ -186,9 +177,9 @@ func (m *Manager) AddFiltering(
ruleID: ruleID,
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
ip: request.IP.String(),
dst: request.Direction == fw.RuleDirectionOUT,
v6: request.IP.To4() == nil,
}
if ipsetName != "" {
// ipset name is defined and it means that this rule was created
@ -199,7 +190,7 @@ func (m *Manager) AddFiltering(
}
}
return rule, nil
return []fw.Rule{rule}, nil
}
// DeleteRule from the firewall by rule definition
@ -272,27 +263,31 @@ func (m *Manager) Reset() error {
func (m *Manager) AllowNetbird() error {
if m.wgIface.IsUserspaceBind() {
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"",
fw.RuleRequest{
IP: net.ParseIP("0.0.0.0"),
Proto: "all",
SrcPort: nil,
DstPort: nil,
Direction: fw.RuleDirectionIN,
Action: fw.ActionAccept,
IPSetName: "",
Comment: "",
},
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionOUT,
fw.ActionAccept,
"",
"",
fw.RuleRequest{
IP: net.ParseIP("0.0.0.0"),
Proto: "all",
SrcPort: nil,
DstPort: nil,
Direction: fw.RuleDirectionOUT,
Action: fw.ActionAccept,
IPSetName: "",
Comment: "",
},
)
return err
}

View File

@ -96,16 +96,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddFiltering(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
func (m *Manager) AddFiltering(request fw.RuleRequest) ([]fw.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -116,16 +107,16 @@ func (m *Manager) AddFiltering(
chain *nftables.Chain
)
if direction == fw.RuleDirectionOUT {
if request.Direction == fw.RuleDirectionOUT {
table, chain, err = m.chain(
ip,
request.IP,
FilterOutputChainName,
nftables.ChainHookOutput,
nftables.ChainPriorityFilter,
nftables.ChainTypeFilter)
} else {
table, chain, err = m.chain(
ip,
request.IP,
FilterInputChainName,
nftables.ChainHookInput,
nftables.ChainPriorityFilter,
@ -135,22 +126,22 @@ func (m *Manager) AddFiltering(
return nil, err
}
rawIP := ip.To4()
rawIP := request.IP.To4()
if rawIP == nil {
rawIP = ip.To16()
rawIP = request.IP.To16()
}
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
rulesetID := m.getRulesetID(request.IP, request.Proto, request.SrcPort, request.DstPort, request.Direction, request.Action, request.IPSetName)
if ipsetName != "" {
if request.IPSetName != "" {
// if we already have set with given name, just add ip to the set
// and return rule with new ID in other case let's create rule
// with fresh created set and set element
var isSetNew bool
ipset, err = m.rConn.GetSetByName(table, ipsetName)
ipset, err = m.rConn.GetSetByName(table, request.IPSetName)
if err != nil {
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
if ipset, err = m.createSet(table, rawIP, request.IPSetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
isSetNew = true
@ -168,7 +159,11 @@ func (m *Manager) AddFiltering(
// just add new rule to the ruleset and return new fw.Rule object
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
return m.rulesetManager.addRule(ruleset, rawIP)
r, err := m.rulesetManager.addRule(ruleset, rawIP)
if err != nil {
return nil, fmt.Errorf("add rule: %v", err)
}
return []fw.Rule{r}, nil
}
// if ipset exists but it is not linked to rule for given direction
// create new rule for direction and bind ipset to it later
@ -176,7 +171,7 @@ func (m *Manager) AddFiltering(
}
ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT {
if request.Direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
}
expressions := []expr.Any{
@ -188,7 +183,7 @@ func (m *Manager) AddFiltering(
},
}
if proto != "all" {
if request.Proto != "all" {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
@ -197,7 +192,7 @@ func (m *Manager) AddFiltering(
})
var protoData []byte
switch proto {
switch request.Proto {
case fw.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case fw.ProtocolUDP:
@ -205,7 +200,7 @@ func (m *Manager) AddFiltering(
case fw.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
return nil, fmt.Errorf("unsupported protocol: %s", request.Proto)
}
expressions = append(expressions, &expr.Cmp{
Register: 1,
@ -225,7 +220,7 @@ func (m *Manager) AddFiltering(
}
// change to destination address position if need
if direction == fw.RuleDirectionOUT {
if request.Direction == fw.RuleDirectionOUT {
addrOffset += addrLen
}
@ -250,14 +245,14 @@ func (m *Manager) AddFiltering(
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipsetName,
SetName: request.IPSetName,
SetID: ipset.ID,
},
)
}
}
if sPort != nil && len(sPort.Values) != 0 {
if request.SrcPort != nil && len(request.SrcPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
@ -268,12 +263,12 @@ func (m *Manager) AddFiltering(
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*sPort),
Data: encodePort(*request.SrcPort),
},
)
}
if dPort != nil && len(dPort.Values) != 0 {
if request.DstPort != nil && len(request.DstPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
@ -284,18 +279,18 @@ func (m *Manager) AddFiltering(
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*dPort),
Data: encodePort(*request.DstPort),
},
)
}
if action == fw.ActionAccept {
if request.Action == fw.ActionAccept {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
} else {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
userData := []byte(strings.Join([]string{rulesetID, request.Comment}, " "))
rule := m.rConn.InsertRule(&nftables.Rule{
Table: table,
@ -309,7 +304,11 @@ func (m *Manager) AddFiltering(
}
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
return m.rulesetManager.addRule(ruleset, rawIP)
r, err := m.rulesetManager.addRule(ruleset, rawIP)
if err != nil {
return nil, fmt.Errorf("add rule: %v", err)
}
return []fw.Rule{r}, nil
}
// getRulesetID returns ruleset ID based on given parameters

View File

@ -81,26 +81,17 @@ func Create(iface IFaceMapper) (*Manager, error) {
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddFiltering(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
func (m *Manager) AddFiltering(request fw.RuleRequest) ([]fw.Rule, error) {
r := Rule{
id: uuid.New().String(),
ip: ip,
ip: request.IP,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
direction: direction,
drop: action == fw.ActionDrop,
comment: comment,
direction: request.Direction,
drop: request.Action == fw.ActionDrop,
comment: request.Comment,
}
if ipNormalized := ip.To4(); ipNormalized != nil {
if ipNormalized := request.IP.To4(); ipNormalized != nil {
r.ipLayer = layers.LayerTypeIPv4
r.ip = ipNormalized
}
@ -109,15 +100,15 @@ func (m *Manager) AddFiltering(
r.matchByIP = false
}
if sPort != nil && len(sPort.Values) == 1 {
r.sPort = uint16(sPort.Values[0])
if request.SrcPort != nil && len(request.SrcPort.Values) == 1 {
r.sPort = uint16(request.SrcPort.Values[0])
}
if dPort != nil && len(dPort.Values) == 1 {
r.dPort = uint16(dPort.Values[0])
if request.DstPort != nil && len(request.DstPort.Values) == 1 {
r.dPort = uint16(request.DstPort.Values[0])
}
switch proto {
switch request.Proto {
case fw.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case fw.ProtocolUDP:
@ -132,7 +123,7 @@ func (m *Manager) AddFiltering(
}
m.mutex.Lock()
if direction == fw.RuleDirectionIN {
if request.Direction == fw.RuleDirectionIN {
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
@ -145,7 +136,7 @@ func (m *Manager) AddFiltering(
}
m.mutex.Unlock()
return &r, nil
return []fw.Rule{&r}, nil
}
// DeleteRule from the firewall by rule definition

View File

@ -287,13 +287,14 @@ func (d *DefaultManager) addOutRules(
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule)
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
@ -305,7 +306,7 @@ func (d *DefaultManager) addOutRules(
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule), nil
return append(rules, rule...), nil
}
// getRuleID() returns unique ID for the rule based on its parameters.