Pass group to firewall

This commit is contained in:
Fata Nugraha 2023-08-04 19:50:57 +07:00 committed by Brian May
parent 755e522eff
commit 7c140daf07
5 changed files with 29 additions and 21 deletions

View File

@ -319,7 +319,7 @@ class FirewallClient:
def setup(self, subnets_include, subnets_exclude, nslist, def setup(self, subnets_include, subnets_exclude, nslist,
redirectport_v6, redirectport_v4, dnsport_v6, dnsport_v4, udp, redirectport_v6, redirectport_v4, dnsport_v6, dnsport_v4, udp,
user, tmark): user, group, tmark):
self.subnets_include = subnets_include self.subnets_include = subnets_include
self.subnets_exclude = subnets_exclude self.subnets_exclude = subnets_exclude
self.nslist = nslist self.nslist = nslist
@ -329,6 +329,7 @@ class FirewallClient:
self.dnsport_v4 = dnsport_v4 self.dnsport_v4 = dnsport_v4
self.udp = udp self.udp = udp
self.user = user self.user = user
self.group = group
self.tmark = tmark self.tmark = tmark
def check(self): def check(self):
@ -367,9 +368,14 @@ class FirewallClient:
user = bytes(self.user, 'utf-8') user = bytes(self.user, 'utf-8')
else: else:
user = b'%d' % self.user user = b'%d' % self.user
if self.group is None:
self.pfile.write(b'GO %d %s %s %d\n' % group = b'-'
(udp, user, bytes(self.tmark, 'ascii'), os.getpid())) elif isinstance(self.group, str):
group = bytes(self.group, 'utf-8')
else:
group = b'%d' % self.group
self.pfile.write(b'GO %d %s %s %s %d\n' %
(udp, user, group, bytes(self.tmark, 'ascii'), os.getpid()))
self.pfile.flush() self.pfile.flush()
line = self.pfile.readline() line = self.pfile.readline()

View File

@ -270,13 +270,15 @@ def main(method_name, syslog):
_, _, args = line.partition(" ") _, _, args = line.partition(" ")
global sshuttle_pid global sshuttle_pid
udp, user, tmark, sshuttle_pid = args.strip().split(" ", 3) udp, user, group, tmark, sshuttle_pid = args.strip().split(" ", 4)
udp = bool(int(udp)) udp = bool(int(udp))
sshuttle_pid = int(sshuttle_pid) sshuttle_pid = int(sshuttle_pid)
if user == '-': if user == '-':
user = None user = None
debug2('Got udp: %r, user: %r, tmark: %s, sshuttle_pid: %d' % if group == '-':
(udp, user, tmark, sshuttle_pid)) group = None
debug2('Got udp: %r, user: %r, group: %r, tmark: %s, sshuttle_pid: %d' %
(udp, user, group, tmark, sshuttle_pid))
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6] subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6] nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
@ -291,14 +293,14 @@ def main(method_name, syslog):
method.setup_firewall( method.setup_firewall(
port_v6, dnsport_v6, nslist_v6, port_v6, dnsport_v6, nslist_v6,
socket.AF_INET6, subnets_v6, udp, socket.AF_INET6, subnets_v6, udp,
user, tmark) user, group, tmark)
if subnets_v4 or nslist_v4: if subnets_v4 or nslist_v4:
debug2('setting up IPv4.') debug2('setting up IPv4.')
method.setup_firewall( method.setup_firewall(
port_v4, dnsport_v4, nslist_v4, port_v4, dnsport_v4, nslist_v4,
socket.AF_INET, subnets_v4, udp, socket.AF_INET, subnets_v4, udp,
user, tmark) user, group, tmark)
flush_systemd_dns_cache() flush_systemd_dns_cache()
stdout.write('STARTED\n') stdout.write('STARTED\n')
@ -334,7 +336,7 @@ def main(method_name, syslog):
try: try:
if subnets_v6 or nslist_v6: if subnets_v6 or nslist_v6:
debug2('undoing IPv6 changes.') debug2('undoing IPv6 changes.')
method.restore_firewall(port_v6, socket.AF_INET6, udp, user) method.restore_firewall(port_v6, socket.AF_INET6, udp, user, group)
except Exception: except Exception:
try: try:
debug1("Error trying to undo IPv6 firewall.") debug1("Error trying to undo IPv6 firewall.")
@ -345,7 +347,7 @@ def main(method_name, syslog):
try: try:
if subnets_v4 or nslist_v4: if subnets_v4 or nslist_v4:
debug2('undoing IPv4 changes.') debug2('undoing IPv4 changes.')
method.restore_firewall(port_v4, socket.AF_INET, udp, user) method.restore_firewall(port_v4, socket.AF_INET, udp, user, group)
except Exception: except Exception:
try: try:
debug1("Error trying to undo IPv4 firewall.") debug1("Error trying to undo IPv4 firewall.")

View File

@ -90,10 +90,10 @@ class BaseMethod(object):
(key, self.name)) (key, self.name))
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, def setup_firewall(self, port, dnsport, nslist, family, subnets, udp,
user, tmark): user, group, tmark):
raise NotImplementedError() raise NotImplementedError()
def restore_firewall(self, port, family, udp, user): def restore_firewall(self, port, family, udp, user, group):
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod

View File

@ -31,17 +31,17 @@ class Method(BaseMethod):
chain = 'sshuttle-%s' % port chain = 'sshuttle-%s' % port
# basic cleanup/setup of chains # basic cleanup/setup of chains
self.restore_firewall(port, family, udp, user) self.restore_firewall(port, family, udp, user, group)
_ipt('-N', chain) _ipt('-N', chain)
_ipt('-F', chain) _ipt('-F', chain)
if user is not None or group is not None: if user is not None or group is not None:
margs = ['-I', 'OUTPUT', '1', '-m', 'owner'] margs = ['-I', 'OUTPUT', '1', '-m', 'owner']
if user is not None: if user is not None:
margs.append('--uid-owner', str(user)) margs += ['--uid-owner', str(user)]
if group is not None: if group is not None:
margs.append('--gid-owner', str(group)) margs += ['--gid-owner', str(group)]
margs = args.append('-j', 'MARK', '--set-mark', str(port)) margs += ['-j', 'MARK', '--set-mark', str(port)]
nonfatal(_ipm, *margs) nonfatal(_ipm, *margs)
args = '-m', 'mark', '--mark', str(port), '-j', chain args = '-m', 'mark', '--mark', str(port), '-j', chain
else: else:
@ -104,10 +104,10 @@ class Method(BaseMethod):
if user is not None or group is not None: if user is not None or group is not None:
margs = ['-D', 'OUTPUT', '-m', 'owner'] margs = ['-D', 'OUTPUT', '-m', 'owner']
if user is not None: if user is not None:
margs.append('--uid-owner', str(user)) margs += ['--uid-owner', str(user)]
if group is not None: if group is not None:
margs.append('--gid-owner', str(group)) margs += ['--gid-owner', str(group)]
margs = args.append('-j', 'MARK', '--set-mark', str(port)) margs += ['-j', 'MARK', '--set-mark', str(port)]
nonfatal(_ipm, *margs) nonfatal(_ipm, *margs)
args = '-m', 'mark', '--mark', str(port), '-j', chain args = '-m', 'mark', '--mark', str(port), '-j', chain

View File

@ -134,7 +134,7 @@ class Method(BaseMethod):
divert_chain = 'sshuttle-d-%s' % port divert_chain = 'sshuttle-d-%s' % port
# basic cleanup/setup of chains # basic cleanup/setup of chains
self.restore_firewall(port, family, udp, user) self.restore_firewall(port, family, udp, user, group)
_ipt('-N', mark_chain) _ipt('-N', mark_chain)
_ipt('-F', mark_chain) _ipt('-F', mark_chain)