mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-13 17:07:30 +02:00
refactor add filter acl
This commit is contained in:
@ -96,16 +96,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
func (m *Manager) AddFiltering(request fw.RuleRequest) ([]fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@ -116,16 +107,16 @@ func (m *Manager) AddFiltering(
|
||||
chain *nftables.Chain
|
||||
)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
if request.Direction == fw.RuleDirectionOUT {
|
||||
table, chain, err = m.chain(
|
||||
ip,
|
||||
request.IP,
|
||||
FilterOutputChainName,
|
||||
nftables.ChainHookOutput,
|
||||
nftables.ChainPriorityFilter,
|
||||
nftables.ChainTypeFilter)
|
||||
} else {
|
||||
table, chain, err = m.chain(
|
||||
ip,
|
||||
request.IP,
|
||||
FilterInputChainName,
|
||||
nftables.ChainHookInput,
|
||||
nftables.ChainPriorityFilter,
|
||||
@ -135,22 +126,22 @@ func (m *Manager) AddFiltering(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawIP := ip.To4()
|
||||
rawIP := request.IP.To4()
|
||||
if rawIP == nil {
|
||||
rawIP = ip.To16()
|
||||
rawIP = request.IP.To16()
|
||||
}
|
||||
|
||||
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
||||
rulesetID := m.getRulesetID(request.IP, request.Proto, request.SrcPort, request.DstPort, request.Direction, request.Action, request.IPSetName)
|
||||
|
||||
if ipsetName != "" {
|
||||
if request.IPSetName != "" {
|
||||
// if we already have set with given name, just add ip to the set
|
||||
// and return rule with new ID in other case let's create rule
|
||||
// with fresh created set and set element
|
||||
|
||||
var isSetNew bool
|
||||
ipset, err = m.rConn.GetSetByName(table, ipsetName)
|
||||
ipset, err = m.rConn.GetSetByName(table, request.IPSetName)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
||||
if ipset, err = m.createSet(table, rawIP, request.IPSetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
}
|
||||
isSetNew = true
|
||||
@ -168,7 +159,11 @@ func (m *Manager) AddFiltering(
|
||||
// just add new rule to the ruleset and return new fw.Rule object
|
||||
|
||||
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
r, err := m.rulesetManager.addRule(ruleset, rawIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add rule: %v", err)
|
||||
}
|
||||
return []fw.Rule{r}, nil
|
||||
}
|
||||
// if ipset exists but it is not linked to rule for given direction
|
||||
// create new rule for direction and bind ipset to it later
|
||||
@ -176,7 +171,7 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
if request.Direction == fw.RuleDirectionOUT {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
@ -188,7 +183,7 @@ func (m *Manager) AddFiltering(
|
||||
},
|
||||
}
|
||||
|
||||
if proto != "all" {
|
||||
if request.Proto != "all" {
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
@ -197,7 +192,7 @@ func (m *Manager) AddFiltering(
|
||||
})
|
||||
|
||||
var protoData []byte
|
||||
switch proto {
|
||||
switch request.Proto {
|
||||
case fw.ProtocolTCP:
|
||||
protoData = []byte{unix.IPPROTO_TCP}
|
||||
case fw.ProtocolUDP:
|
||||
@ -205,7 +200,7 @@ func (m *Manager) AddFiltering(
|
||||
case fw.ProtocolICMP:
|
||||
protoData = []byte{unix.IPPROTO_ICMP}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", request.Proto)
|
||||
}
|
||||
expressions = append(expressions, &expr.Cmp{
|
||||
Register: 1,
|
||||
@ -225,7 +220,7 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
|
||||
// change to destination address position if need
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
if request.Direction == fw.RuleDirectionOUT {
|
||||
addrOffset += addrLen
|
||||
}
|
||||
|
||||
@ -250,14 +245,14 @@ func (m *Manager) AddFiltering(
|
||||
expressions = append(expressions,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipsetName,
|
||||
SetName: request.IPSetName,
|
||||
SetID: ipset.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) != 0 {
|
||||
if request.SrcPort != nil && len(request.SrcPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@ -268,12 +263,12 @@ func (m *Manager) AddFiltering(
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*sPort),
|
||||
Data: encodePort(*request.SrcPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if dPort != nil && len(dPort.Values) != 0 {
|
||||
if request.DstPort != nil && len(request.DstPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@ -284,18 +279,18 @@ func (m *Manager) AddFiltering(
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*dPort),
|
||||
Data: encodePort(*request.DstPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if action == fw.ActionAccept {
|
||||
if request.Action == fw.ActionAccept {
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
} else {
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
||||
userData := []byte(strings.Join([]string{rulesetID, request.Comment}, " "))
|
||||
|
||||
rule := m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: table,
|
||||
@ -309,7 +304,11 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
|
||||
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
r, err := m.rulesetManager.addRule(ruleset, rawIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add rule: %v", err)
|
||||
}
|
||||
return []fw.Rule{r}, nil
|
||||
}
|
||||
|
||||
// getRulesetID returns ruleset ID based on given parameters
|
||||
|
Reference in New Issue
Block a user