added the init command

This commit is contained in:
k4yt3x 2021-05-29 18:42:39 +00:00
parent 08e51c97eb
commit 4a2e1ebf78
2 changed files with 66 additions and 31 deletions

View File

@ -62,7 +62,7 @@ PEER_OPTIONAL_ATTRIBUTES = [
] ]
KEY_TYPE = { KEY_TYPE = {
"name": str, "Name": str,
"Address": list, "Address": list,
"Endpoint": str, "Endpoint": str,
"AllowedIPs": list, "AllowedIPs": list,
@ -87,6 +87,35 @@ class DatabaseManager:
self.database_template = {"peers": {}} self.database_template = {"peers": {}}
self.wireguard = WireGuard() self.wireguard = WireGuard()
def init(self):
"""initialize an empty database file"""
if not self.database_path.exists():
with self.database_path.open(mode="w", encoding="utf-8") as database_file:
writer = csv.DictWriter(
database_file, KEY_TYPE.keys(), quoting=csv.QUOTE_ALL
)
writer.writeheader()
print(f"Empty database file {self.database_path} has been created")
else:
database = self.read_database()
# check values that cannot be generated automatically
for key in ["Address", "Endpoint"]:
for peer in database["peers"]:
if database["peers"][peer].get(key) is None:
print(f"The value of {key} cannot be automatically generated")
sys.exit(1)
# automatically generate missing values
for peer in database["peers"]:
if database["peers"][peer].get("ListenPort") is None:
database["peers"][peer]["ListenPort"] = 51820
if database["peers"][peer].get("PrivateKey") is None:
privatekey = self.wireguard.genkey()
database["peers"][peer]["PrivateKey"] = privatekey
self.write_database(database)
def read_database(self): def read_database(self):
"""read database file into dict """read database file into dict
@ -110,7 +139,7 @@ class DatabaseManager:
peer[key] = int(peer[key]) peer[key] = int(peer[key])
elif KEY_TYPE[key] == bool: elif KEY_TYPE[key] == bool:
peer[key] = peer[key].lower() == "true" peer[key] = peer[key].lower() == "true"
database["peers"][peer.pop("name")] = peer database["peers"][peer.pop("Name")] = peer
return database return database
@ -128,7 +157,7 @@ class DatabaseManager:
writer.writeheader() writer.writeheader()
data = copy.deepcopy(data) data = copy.deepcopy(data)
for peer in data["peers"]: for peer in data["peers"]:
data["peers"][peer]["name"] = peer data["peers"][peer]["Name"] = peer
for key in data["peers"][peer]: for key in data["peers"][peer]:
if isinstance(data["peers"][peer][key], list): if isinstance(data["peers"][peer][key], list):
data["peers"][peer][key] = ",".join(data["peers"][peer][key]) data["peers"][peer][key] = ",".join(data["peers"][peer][key])
@ -140,7 +169,7 @@ class DatabaseManager:
def addpeer( def addpeer(
self, self,
name: str, Name: str,
Address: list, Address: list,
Endpoint: str = None, Endpoint: str = None,
AllowedIPs: list = None, AllowedIPs: list = None,
@ -159,26 +188,26 @@ class DatabaseManager:
): ):
database = self.read_database() database = self.read_database()
if name in database["peers"]: if Name in database["peers"]:
print(f"Peer with name {name} already exists") print(f"Peer with name {Name} already exists")
return return
database["peers"][name] = {} database["peers"][Name] = {}
# if private key is not specified, generate one # if private key is not specified, generate one
if locals().get("PrivateKey") is None: if locals().get("PrivateKey") is None:
privatekey = self.wireguard.genkey() privatekey = self.wireguard.genkey()
database["peers"][name]["PrivateKey"] = privatekey database["peers"][Name]["PrivateKey"] = privatekey
for key in INTERFACE_ATTRIBUTES + PEER_ATTRIBUTES: for key in INTERFACE_ATTRIBUTES + PEER_ATTRIBUTES:
if locals().get(key) is not None: if locals().get(key) is not None:
database["peers"][name][key] = locals().get(key) database["peers"][Name][key] = locals().get(key)
self.write_database(database) self.write_database(database)
def updatepeer( def updatepeer(
self, self,
name: str, Name: str,
Address: list = None, Address: list = None,
Endpoint: str = None, Endpoint: str = None,
AllowedIPs: list = None, AllowedIPs: list = None,
@ -197,44 +226,44 @@ class DatabaseManager:
): ):
database = self.read_database() database = self.read_database()
if name not in database["peers"]: if Name not in database["peers"]:
print(f"Peer with name {name} does not exist") print(f"Peer with name {Name} does not exist")
return return
for key in INTERFACE_ATTRIBUTES + PEER_ATTRIBUTES: for key in INTERFACE_ATTRIBUTES + PEER_ATTRIBUTES:
if locals().get(key) is not None: if locals().get(key) is not None:
database["peers"][name][key] = locals().get(key) database["peers"][Name][key] = locals().get(key)
self.write_database(database) self.write_database(database)
def delpeer(self, name: str): def delpeer(self, Name: str):
database = self.read_database() database = self.read_database()
# abort if user doesn't exist # abort if user doesn't exist
if name not in database["peers"]: if Name not in database["peers"]:
print(f"Peer with ID {name} does not exist") print(f"Peer with ID {Name} does not exist")
return return
database["peers"].pop(name, None) database["peers"].pop(Name, None)
# write changes into database # write changes into database
self.write_database(database) self.write_database(database)
def showpeers(self, name: str, style: str = "table", simplify: bool = False): def showpeers(self, Name: str, style: str = "table", simplify: bool = False):
database = self.read_database() database = self.read_database()
# if name is specified, show the specified peer # if name is specified, show the specified peer
if name is not None: if Name is not None:
if name not in database["peers"]: if Name not in database["peers"]:
print(f"Peer with ID {name} does not exist") print(f"Peer with ID {Name} does not exist")
return return
peers = [name] peers = [Name]
# otherwise, show all peers # otherwise, show all peers
else: else:
peers = [p for p in database["peers"]] peers = [p for p in database["peers"]]
field_names = ["name"] field_names = ["Name"]
# exclude all columns that only have None's in simplified mode # exclude all columns that only have None's in simplified mode
if simplify is True: if simplify is True:
@ -264,14 +293,14 @@ class DatabaseManager:
database["peers"][peer].get(k) database["peers"][peer].get(k)
if not isinstance(database["peers"][peer].get(k), list) if not isinstance(database["peers"][peer].get(k), list)
else ",".join(database["peers"][peer].get(k)) else ",".join(database["peers"][peer].get(k))
for k in [i for i in table.field_names if i != "name"] for k in [i for i in table.field_names if i != "Name"]
] ]
) )
print(table) print(table)
except NameError: except NameError:
print("PrettyTable is not installed", sys.stderr) print("PrettyTable is not installed", file=sys.stderr)
print("Displaying in table mode is not available", sys.stderr) print("Displaying in table mode is not available", file=sys.stderr)
sys.exit(1) sys.exit(1)
# if the style is text # if the style is text
@ -289,12 +318,12 @@ class DatabaseManager:
) )
print() print()
def genconfig(self, name: str, output: pathlib.Path): def genconfig(self, Name: str, output: pathlib.Path):
database = self.read_database() database = self.read_database()
# check if peer ID is specified # check if peer ID is specified
if name is not None: if Name is not None:
peers = [name] peers = [Name]
else: else:
peers = [p for p in database["peers"]] peers = [p for p in database["peers"]]
@ -307,7 +336,7 @@ class DatabaseManager:
) )
raise FileExistsError raise FileExistsError
elif not output.exists(): elif not output.exists():
print(f"Creating output directory: {output}", sys.stderr) print(f"Creating output directory: {output}", file=sys.stderr)
output.mkdir(exist_ok=True) output.mkdir(exist_ok=True)
# for every peer in the database # for every peer in the database

View File

@ -37,6 +37,9 @@ def parse_arguments():
# add subparsers for commands # add subparsers for commands
subparsers = parser.add_subparsers(dest="command") subparsers = parser.add_subparsers(dest="command")
# initialize empty database
subparsers.add_parser("init")
# add new peer # add new peer
addpeer = subparsers.add_parser("addpeer") addpeer = subparsers.add_parser("addpeer")
addpeer.add_argument("name", help="Name used to identify this node") addpeer.add_argument("name", help="Name used to identify this node")
@ -145,7 +148,10 @@ def main():
database_manager = DatabaseManager(args.database) database_manager = DatabaseManager(args.database)
if args.command == "addpeer": if args.command == "init":
database_manager.init()
elif args.command == "addpeer":
database_manager.addpeer( database_manager.addpeer(
args.name, args.name,
args.address, args.address,