#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Name: Database Manager
Creator: K4YT3X
Date Created: July 19, 2020
Last Modified: May 29, 2021
"""

# built-in imports
import contextlib
import copy
import csv
import pathlib
import sys

# third party imports
with contextlib.suppress(ImportError):
    from prettytable import PrettyTable

# local imports
from .wireguard import WireGuard

INTERFACE_ATTRIBUTES = [
    "Address",
    "ListenPort",
    "FwMark",
    "PrivateKey",
    "DNS",
    "MTU",
    "Table",
    "PreUp",
    "PostUp",
    "PreDown",
    "PostDown",
    "SaveConfig",
]

INTERFACE_OPTIONAL_ATTRIBUTES = [
    "ListenPort",
    "FwMark",
    "DNS",
    "MTU",
    "Table",
    "PreUp",
    "PostUp",
    "PreDown",
    "PostDown",
    "SaveConfig",
]

PEER_ATTRIBUTES = [
    "PublicKey",
    "PresharedKey",
    "AllowedIPs",
    "Endpoint",
    "PersistentKeepalive",
]

PEER_OPTIONAL_ATTRIBUTES = [
    "PersistentKeepalive",
]

# Column order of the CSV database and the Python type each column is parsed as.
KEY_TYPE = {
    "Name": str,
    "Address": list,
    "Endpoint": str,
    "AllowedIPs": list,
    "ListenPort": int,
    "PersistentKeepalive": int,
    "FwMark": str,
    "PrivateKey": str,
    "DNS": str,
    "MTU": int,
    "Table": str,
    "PreUp": str,
    "PostUp": str,
    "PreDown": str,
    "PostDown": str,
    "SaveConfig": bool,
}


def _display_value(value):
    """Return ``value`` for display, joining list values with commas."""
    if isinstance(value, list):
        return ",".join(value)
    return value


def _format_config_value(value) -> str:
    """Render a database value as text for a WireGuard .conf file.

    Booleans are lowercased because wg-quick only accepts the literals
    ``true``/``false`` (e.g. for SaveConfig); lists are joined with ", ".
    """
    if isinstance(value, bool):
        return str(value).lower()
    if isinstance(value, list):
        return ", ".join(value)
    return str(value)


class DatabaseManager:
    def __init__(self, database_path: pathlib.Path):
        self.database_path = database_path
        self.database_template = {"peers": {}}
        self.wireguard = WireGuard()

    def init(self):
        """Initialize the database file.

        If the file does not exist, create it containing only the CSV header.
        Otherwise, require every peer to have an Address and an Endpoint
        (exits with status 1 if one is missing), fill in ListenPort 51820 and
        a key from ``self.wireguard.genkey()`` where those are missing, and
        write the database back.
        """
        if not self.database_path.exists():
            # newline="" lets the csv module control line endings itself
            with self.database_path.open(
                mode="w", encoding="utf-8", newline=""
            ) as database_file:
                writer = csv.DictWriter(
                    database_file, KEY_TYPE.keys(), quoting=csv.QUOTE_ALL
                )
                writer.writeheader()
            print(f"Empty database file {self.database_path} has been created")
            return

        database = self.read_database()

        # check values that cannot be generated automatically
        for key in ["Address", "Endpoint"]:
            for peer in database["peers"]:
                if database["peers"][peer].get(key) is None:
                    print(f"The value of {key} cannot be automatically generated")
                    sys.exit(1)

        # automatically generate missing values
        for attributes in database["peers"].values():
            if attributes.get("ListenPort") is None:
                attributes["ListenPort"] = 51820
            if attributes.get("PrivateKey") is None:
                attributes["PrivateKey"] = self.wireguard.genkey()

        self.write_database(database)

    def read_database(self):
        """Read the database file into a dict.

        Returns:
            dict: ``{"peers": {name: {column: value}}}``. Empty cells become
            None; list, int and bool columns are converted according to
            KEY_TYPE. When the database file does not exist, a fresh empty
            database is returned (a copy, so callers may mutate it freely).
        """
        if not self.database_path.is_file():
            # return a copy: callers such as addpeer mutate the result, which
            # must not leak back into the shared template
            return copy.deepcopy(self.database_template)

        database = copy.deepcopy(self.database_template)
        with self.database_path.open(
            mode="r", encoding="utf-8", newline=""
        ) as database_file:
            for row in csv.DictReader(database_file):
                for key, value in row.items():
                    if value == "":
                        row[key] = None
                    elif KEY_TYPE[key] is list:
                        # tolerate hand-edited cells such as "a, b"
                        row[key] = [item.strip() for item in value.split(",")]
                    elif KEY_TYPE[key] is int:
                        row[key] = int(value)
                    elif KEY_TYPE[key] is bool:
                        row[key] = value.lower() == "true"
                database["peers"][row.pop("Name")] = row

        return database

    def write_database(self, data: dict):
        """Dump data into the database file, replacing its contents.

        Args:
            data (dict): content of database, as returned by read_database
        """
        with self.database_path.open(
            mode="w", encoding="utf-8", newline=""
        ) as database_file:
            writer = csv.DictWriter(
                database_file, KEY_TYPE.keys(), quoting=csv.QUOTE_ALL
            )
            writer.writeheader()

            for name, attributes in data["peers"].items():
                # lists are stored comma-separated; csv.DictWriter stringifies
                # ints and bools with str() and writes None as an empty cell
                row = {
                    key: ",".join(value) if isinstance(value, list) else value
                    for key, value in attributes.items()
                }
                row["Name"] = name
                writer.writerow(row)

    def addpeer(
        self,
        Name: str,
        Address: list,
        Endpoint: str = None,
        AllowedIPs: list = None,
        ListenPort: int = None,
        PersistentKeepalive: int = None,
        FwMark: str = None,
        PrivateKey: str = None,
        DNS: str = None,
        MTU: int = None,
        Table: str = None,
        PreUp: str = None,
        PostUp: str = None,
        PreDown: str = None,
        PostDown: str = None,
        SaveConfig: bool = None,
    ):
        """Add a new peer called ``Name`` to the database.

        Prints a message and returns without changes if the name already
        exists. Arguments left as None are not stored. If PrivateKey is None,
        one is generated with ``self.wireguard.genkey()``.
        """
        # snapshot the arguments before any other local name is bound
        arguments = dict(locals())
        database = self.read_database()

        if Name in database["peers"]:
            print(f"Peer with name {Name} already exists")
            return

        peer = {}

        # if private key is not specified, generate one
        if PrivateKey is None:
            peer["PrivateKey"] = self.wireguard.genkey()

        for key in INTERFACE_ATTRIBUTES + PEER_ATTRIBUTES:
            if arguments.get(key) is not None:
                peer[key] = arguments[key]

        database["peers"][Name] = peer
        self.write_database(database)

    def updatepeer(
        self,
        Name: str,
        Address: list = None,
        Endpoint: str = None,
        AllowedIPs: list = None,
        ListenPort: int = None,
        PersistentKeepalive: int = None,
        FwMark: str = None,
        PrivateKey: str = None,
        DNS: str = None,
        MTU: int = None,
        Table: str = None,
        PreUp: str = None,
        PostUp: str = None,
        PreDown: str = None,
        PostDown: str = None,
        SaveConfig: bool = None,
    ):
        """Overwrite the given attributes of an existing peer.

        Arguments left as None keep their current value. Prints a message and
        returns without changes if the peer does not exist.
        """
        # snapshot the arguments before any other local name is bound
        arguments = dict(locals())
        database = self.read_database()

        if Name not in database["peers"]:
            print(f"Peer with name {Name} does not exist")
            return

        for key in INTERFACE_ATTRIBUTES + PEER_ATTRIBUTES:
            if arguments.get(key) is not None:
                database["peers"][Name][key] = arguments[key]

        self.write_database(database)

    def delpeer(self, Name: str):
        """Remove the peer called ``Name``; prints a message if it does not exist."""
        database = self.read_database()

        # abort if user doesn't exist
        if Name not in database["peers"]:
            print(f"Peer with ID {Name} does not exist")
            return

        del database["peers"][Name]

        # write changes into database
        self.write_database(database)

    def showpeers(self, Name: str, style: str = "table", simplify: bool = False):
        """Print one peer (``Name``) or all peers (``Name`` is None).

        Args:
            Name: peer to show, or None for every peer
            style: "table" (requires prettytable) or "text"
            simplify: when True, omit columns that are None for every shown peer
        """
        database = self.read_database()

        # if name is specified, show the specified peer
        if Name is not None:
            if Name not in database["peers"]:
                print(f"Peer with ID {Name} does not exist")
                return
            peers = [Name]
        # otherwise, show all peers
        else:
            peers = list(database["peers"])

        attribute_names = INTERFACE_ATTRIBUTES + PEER_ATTRIBUTES
        field_names = ["Name"]

        # exclude all columns that only have None's in simplified mode
        if simplify is True:
            for peer in peers:
                for key in attribute_names:
                    if (
                        database["peers"][peer].get(key) is not None
                        and key not in field_names
                    ):
                        field_names.append(key)
        # include all columns by default
        else:
            field_names += attribute_names

        value_keys = [key for key in field_names if key != "Name"]

        # if the style is table
        # print with prettytable
        if style == "table":
            # PrettyTable is undefined when the optional import failed
            try:
                table = PrettyTable()
            except NameError:
                print("PrettyTable is not installed", file=sys.stderr)
                print("Displaying in table mode is not available", file=sys.stderr)
                sys.exit(1)

            table.field_names = field_names
            for peer in peers:
                table.add_row(
                    [peer]
                    + [_display_value(database["peers"][peer].get(key)) for key in value_keys]
                )
            print(table)

        # if the style is text
        # print in plaintext format
        elif style == "text":
            for peer in peers:
                # the peer name is printed here; "Name" is not a stored attribute
                print(f"{'peer': <14}{peer}")
                for key in value_keys:
                    print(f"{key: <14}{_display_value(database['peers'][peer].get(key))}")
                print()

    def genconfig(self, Name: str, output: pathlib.Path):
        """Write a WireGuard config file ``<peer>.conf`` into ``output``.

        Each file has an [Interface] section for the peer itself and one
        [Peer] section for every other peer in the database.

        Args:
            Name: peer to generate a config for, or None for every peer
            output: output directory; created (with parents) if missing

        Raises:
            FileExistsError: ``output`` exists and is not a directory
        """
        database = self.read_database()

        # check if peer ID is specified
        if Name is not None:
            if Name not in database["peers"]:
                print(f"Peer with ID {Name} does not exist")
                return
            peers = [Name]
        else:
            peers = list(database["peers"])

        # A peer appears as a [Peer] section in every generated config except
        # its own, and its Endpoint is written as "<Endpoint>:<ListenPort>".
        # Refuse to emit "host:None" (or crash on a missing key) mid-generation.
        for p, attributes in database["peers"].items():
            referenced = any(generated != p for generated in peers)
            if (
                referenced
                and attributes.get("Endpoint") is not None
                and attributes.get("ListenPort") is None
            ):
                print(
                    f"Error: peer {p} has an Endpoint but no ListenPort",
                    file=sys.stderr,
                )
                sys.exit(1)

        # check if output directory is valid
        # create output directory if it does not exist
        if output.exists() and not output.is_dir():
            print(
                "Error: output path already exists and is not a directory",
                file=sys.stderr,
            )
            raise FileExistsError
        elif not output.exists():
            print(f"Creating output directory: {output}", file=sys.stderr)
            output.mkdir(parents=True, exist_ok=True)

        # for every peer in the database
        for peer in peers:
            local = database["peers"][peer]
            with (output / f"{peer}.conf").open("w", encoding="utf-8") as config:
                config.write("[Interface]\n")
                config.write("# Name: {}\n".format(peer))
                config.write("Address = {}\n".format(", ".join(local["Address"])))
                config.write("PrivateKey = {}\n".format(local["PrivateKey"]))

                for key in INTERFACE_OPTIONAL_ATTRIBUTES:
                    if local.get(key) is not None:
                        config.write(
                            "{} = {}\n".format(key, _format_config_value(local[key]))
                        )

                # generate [Peer] sections for all other peers
                for p in [i for i in database["peers"] if i != peer]:
                    remote = database["peers"][p]
                    config.write("\n[Peer]\n")
                    config.write("# Name: {}\n".format(p))
                    config.write(
                        "PublicKey = {}\n".format(
                            self.wireguard.pubkey(remote["PrivateKey"])
                        )
                    )

                    if remote.get("Endpoint") is not None:
                        config.write(
                            "Endpoint = {}:{}\n".format(
                                remote["Endpoint"], remote["ListenPort"]
                            )
                        )

                    if remote.get("Address") is not None:
                        # route the peer's own addresses plus any extra AllowedIPs
                        allowed = list(remote["Address"])
                        if remote.get("AllowedIPs") is not None:
                            allowed += remote["AllowedIPs"]
                        config.write("AllowedIPs = {}\n".format(", ".join(allowed)))

                    for key in PEER_OPTIONAL_ATTRIBUTES:
                        if remote.get(key) is not None:
                            config.write(
                                "{} = {}\n".format(
                                    key, _format_config_value(remote[key])
                                )
                            )