refactor add filter acl

This commit is contained in:
Maycon Santos
2023-11-03 15:25:18 +01:00
parent e2f27502e4
commit b6af524187
5 changed files with 109 additions and 112 deletions

View File

@ -96,16 +96,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddFiltering(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
func (m *Manager) AddFiltering(request fw.RuleRequest) ([]fw.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -116,16 +107,16 @@ func (m *Manager) AddFiltering(
chain *nftables.Chain
)
if direction == fw.RuleDirectionOUT {
if request.Direction == fw.RuleDirectionOUT {
table, chain, err = m.chain(
ip,
request.IP,
FilterOutputChainName,
nftables.ChainHookOutput,
nftables.ChainPriorityFilter,
nftables.ChainTypeFilter)
} else {
table, chain, err = m.chain(
ip,
request.IP,
FilterInputChainName,
nftables.ChainHookInput,
nftables.ChainPriorityFilter,
@ -135,22 +126,22 @@ func (m *Manager) AddFiltering(
return nil, err
}
rawIP := ip.To4()
rawIP := request.IP.To4()
if rawIP == nil {
rawIP = ip.To16()
rawIP = request.IP.To16()
}
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
rulesetID := m.getRulesetID(request.IP, request.Proto, request.SrcPort, request.DstPort, request.Direction, request.Action, request.IPSetName)
if ipsetName != "" {
if request.IPSetName != "" {
// if we already have set with given name, just add ip to the set
// and return rule with new ID in other case let's create rule
// with fresh created set and set element
var isSetNew bool
ipset, err = m.rConn.GetSetByName(table, ipsetName)
ipset, err = m.rConn.GetSetByName(table, request.IPSetName)
if err != nil {
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
if ipset, err = m.createSet(table, rawIP, request.IPSetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
isSetNew = true
@ -168,7 +159,11 @@ func (m *Manager) AddFiltering(
// just add new rule to the ruleset and return new fw.Rule object
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
return m.rulesetManager.addRule(ruleset, rawIP)
r, err := m.rulesetManager.addRule(ruleset, rawIP)
if err != nil {
return nil, fmt.Errorf("add rule: %v", err)
}
return []fw.Rule{r}, nil
}
// if ipset exists but it is not linked to rule for given direction
// create new rule for direction and bind ipset to it later
@ -176,7 +171,7 @@ func (m *Manager) AddFiltering(
}
ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT {
if request.Direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
}
expressions := []expr.Any{
@ -188,7 +183,7 @@ func (m *Manager) AddFiltering(
},
}
if proto != "all" {
if request.Proto != "all" {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
@ -197,7 +192,7 @@ func (m *Manager) AddFiltering(
})
var protoData []byte
switch proto {
switch request.Proto {
case fw.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case fw.ProtocolUDP:
@ -205,7 +200,7 @@ func (m *Manager) AddFiltering(
case fw.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
return nil, fmt.Errorf("unsupported protocol: %s", request.Proto)
}
expressions = append(expressions, &expr.Cmp{
Register: 1,
@ -225,7 +220,7 @@ func (m *Manager) AddFiltering(
}
// change to destination address position if need
if direction == fw.RuleDirectionOUT {
if request.Direction == fw.RuleDirectionOUT {
addrOffset += addrLen
}
@ -250,14 +245,14 @@ func (m *Manager) AddFiltering(
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipsetName,
SetName: request.IPSetName,
SetID: ipset.ID,
},
)
}
}
if sPort != nil && len(sPort.Values) != 0 {
if request.SrcPort != nil && len(request.SrcPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
@ -268,12 +263,12 @@ func (m *Manager) AddFiltering(
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*sPort),
Data: encodePort(*request.SrcPort),
},
)
}
if dPort != nil && len(dPort.Values) != 0 {
if request.DstPort != nil && len(request.DstPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
@ -284,18 +279,18 @@ func (m *Manager) AddFiltering(
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*dPort),
Data: encodePort(*request.DstPort),
},
)
}
if action == fw.ActionAccept {
if request.Action == fw.ActionAccept {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
} else {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
userData := []byte(strings.Join([]string{rulesetID, request.Comment}, " "))
rule := m.rConn.InsertRule(&nftables.Rule{
Table: table,
@ -309,7 +304,11 @@ func (m *Manager) AddFiltering(
}
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
return m.rulesetManager.addRule(ruleset, rawIP)
r, err := m.rulesetManager.addRule(ruleset, rawIP)
if err != nil {
return nil, fmt.Errorf("add rule: %v", err)
}
return []fw.Rule{r}, nil
}
// getRulesetID returns ruleset ID based on given parameters