From 4a2e1ebf789d09e67bedc6a11bc5a097a3d91f83 Mon Sep 17 00:00:00 2001 From: k4yt3x Date: Sat, 29 May 2021 18:42:39 +0000 Subject: [PATCH] added the init command --- wg_meshconf/database_manager.py | 89 ++++++++++++++++++++++----------- wg_meshconf/wg_meshconf.py | 8 ++- 2 files changed, 66 insertions(+), 31 deletions(-) diff --git a/wg_meshconf/database_manager.py b/wg_meshconf/database_manager.py index bf33649..3fb69a7 100755 --- a/wg_meshconf/database_manager.py +++ b/wg_meshconf/database_manager.py @@ -62,7 +62,7 @@ PEER_OPTIONAL_ATTRIBUTES = [ ] KEY_TYPE = { - "name": str, + "Name": str, "Address": list, "Endpoint": str, "AllowedIPs": list, @@ -87,6 +87,35 @@ class DatabaseManager: self.database_template = {"peers": {}} self.wireguard = WireGuard() + def init(self): + """initialize an empty database file""" + if not self.database_path.exists(): + with self.database_path.open(mode="w", encoding="utf-8") as database_file: + writer = csv.DictWriter( + database_file, KEY_TYPE.keys(), quoting=csv.QUOTE_ALL + ) + writer.writeheader() + print(f"Empty database file {self.database_path} has been created") + else: + database = self.read_database() + + # check values that cannot be generated automatically + for key in ["Address", "Endpoint"]: + for peer in database["peers"]: + if database["peers"][peer].get(key) is None: + print(f"The value of {key} cannot be automatically generated") + sys.exit(1) + + # automatically generate missing values + for peer in database["peers"]: + if database["peers"][peer].get("ListenPort") is None: + database["peers"][peer]["ListenPort"] = 51820 + + if database["peers"][peer].get("PrivateKey") is None: + privatekey = self.wireguard.genkey() + database["peers"][peer]["PrivateKey"] = privatekey + self.write_database(database) + def read_database(self): """read database file into dict @@ -110,7 +139,7 @@ class DatabaseManager: peer[key] = int(peer[key]) elif KEY_TYPE[key] == bool: peer[key] = peer[key].lower() == "true" - database["peers"][peer.pop("name")] = peer + database["peers"][peer.pop("Name")] = peer return database @@ -128,7 +157,7 @@ class DatabaseManager: writer.writeheader() data = copy.deepcopy(data) for peer in data["peers"]: - data["peers"][peer]["name"] = peer + data["peers"][peer]["Name"] = peer for key in data["peers"][peer]: if isinstance(data["peers"][peer][key], list): data["peers"][peer][key] = ",".join(data["peers"][peer][key]) @@ -140,7 +169,7 @@ class DatabaseManager: def addpeer( self, - name: str, + Name: str, Address: list, Endpoint: str = None, AllowedIPs: list = None, @@ -159,26 +188,26 @@ class DatabaseManager: ): database = self.read_database() - if name in database["peers"]: - print(f"Peer with name {name} already exists") + if Name in database["peers"]: + print(f"Peer with name {Name} already exists") return - database["peers"][name] = {} + database["peers"][Name] = {} # if private key is not specified, generate one if locals().get("PrivateKey") is None: privatekey = self.wireguard.genkey() - database["peers"][name]["PrivateKey"] = privatekey + database["peers"][Name]["PrivateKey"] = privatekey for key in INTERFACE_ATTRIBUTES + PEER_ATTRIBUTES: if locals().get(key) is not None: - database["peers"][name][key] = locals().get(key) + database["peers"][Name][key] = locals().get(key) self.write_database(database) def updatepeer( self, - name: str, + Name: str, Address: list = None, Endpoint: str = None, AllowedIPs: list = None, @@ -197,44 +226,44 @@ class DatabaseManager: ): database = self.read_database() - if name not in database["peers"]: - print(f"Peer with name {name} does not exist") + if Name not in database["peers"]: + print(f"Peer with name {Name} does not exist") return for key in INTERFACE_ATTRIBUTES + PEER_ATTRIBUTES: if locals().get(key) is not None: - database["peers"][name][key] = locals().get(key) + database["peers"][Name][key] = locals().get(key) self.write_database(database) - def delpeer(self, name: str): + def delpeer(self, Name: str): database = self.read_database() # abort if user doesn't exist - if name not in database["peers"]: - print(f"Peer with ID {name} does not exist") + if Name not in database["peers"]: + print(f"Peer with ID {Name} does not exist") return - database["peers"].pop(name, None) + database["peers"].pop(Name, None) # write changes into database self.write_database(database) - def showpeers(self, name: str, style: str = "table", simplify: bool = False): + def showpeers(self, Name: str, style: str = "table", simplify: bool = False): database = self.read_database() # if name is specified, show the specified peer - if name is not None: - if name not in database["peers"]: - print(f"Peer with ID {name} does not exist") + if Name is not None: + if Name not in database["peers"]: + print(f"Peer with ID {Name} does not exist") return - peers = [name] + peers = [Name] # otherwise, show all peers else: peers = [p for p in database["peers"]] - field_names = ["name"] + field_names = ["Name"] # exclude all columns that only have None's in simplified mode if simplify is True: @@ -264,14 +293,14 @@ class DatabaseManager: database["peers"][peer].get(k) if not isinstance(database["peers"][peer].get(k), list) else ",".join(database["peers"][peer].get(k)) - for k in [i for i in table.field_names if i != "name"] + for k in [i for i in table.field_names if i != "Name"] ] ) print(table) except NameError: - print("PrettyTable is not installed", sys.stderr) - print("Displaying in table mode is not available", sys.stderr) + print("PrettyTable is not installed", file=sys.stderr) + print("Displaying in table mode is not available", file=sys.stderr) sys.exit(1) # if the style is text @@ -289,12 +318,12 @@ class DatabaseManager: ) print() - def genconfig(self, name: str, output: pathlib.Path): + def genconfig(self, Name: str, output: pathlib.Path): database = self.read_database() # check if peer ID is specified - if name is not None: - peers = [name] + if Name is not None: + peers = [Name] else: peers = [p for p in database["peers"]] @@ -307,7 +336,7 @@ class DatabaseManager: ) raise FileExistsError elif not output.exists(): - print(f"Creating output directory: {output}", sys.stderr) + print(f"Creating output directory: {output}", file=sys.stderr) output.mkdir(exist_ok=True) # for every peer in the database diff --git a/wg_meshconf/wg_meshconf.py b/wg_meshconf/wg_meshconf.py index 9f16c30..58fad77 100755 --- a/wg_meshconf/wg_meshconf.py +++ b/wg_meshconf/wg_meshconf.py @@ -37,6 +37,9 @@ def parse_arguments(): # add subparsers for commands subparsers = parser.add_subparsers(dest="command") + # initialize empty database + subparsers.add_parser("init") + # add new peer addpeer = subparsers.add_parser("addpeer") addpeer.add_argument("name", help="Name used to identify this node") @@ -145,7 +148,10 @@ def main(): database_manager = DatabaseManager(args.database) - if args.command == "addpeer": + if args.command == "init": + database_manager.init() + + elif args.command == "addpeer": database_manager.addpeer( args.name, args.address,