From 27c61132871c21d3c7c30197c5f6bc14429ec82e Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Fri, 9 Dec 2022 13:29:06 +0530
Subject: [PATCH] Support hypernetworks; move the hypernetwork module to
 diffusion-kit

---
 ui/sd_internal/hypernetwork.py  | 198 --------------------------------
 ui/sd_internal/model_manager.py |   1 -
 ui/sd_internal/runtime2.py      |  11 +-
 3 files changed, 8 insertions(+), 202 deletions(-)
 delete mode 100644 ui/sd_internal/hypernetwork.py
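Note for reviewers: the module deleted below implements a hypernetwork in the AUTOMATIC1111 sense. Per context width, a pair of small residual MLPs remaps the conditioning tensor before cross-attention's to_k/to_v projections see it. Here is a minimal, self-contained sketch of that idea; the names are hypothetical and this is not diffusion-kit's actual API:

import torch

class TinyHypernetLayer(torch.nn.Module):
    """Hypothetical stand-in for HypernetworkModule with the default [1, 2, 1] structure."""
    def __init__(self, dim: int, multiplier: float = 0.5):
        super().__init__()
        # dim -> 2*dim -> dim, matching layer_structure=[1, 2, 1] with no activation
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 2),
            torch.nn.Linear(dim * 2, dim),
        )
        self.multiplier = multiplier

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        # residual update: the original context plus a scaled learned correction
        return context + self.linear(context) * self.multiplier

# one (k, v) module pair per context width, keyed by embedding dim, as in load_hypernetwork
hypernetwork = {768: (TinyHypernetLayer(768), TinyHypernetLayer(768))}

context = torch.randn(2, 77, 768)                        # e.g. CLIP text conditioning
k_mod, v_mod = hypernetwork[context.shape[2]]
context_k, context_v = k_mod(context), v_mod(context)    # what to_k / to_v then consume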
diff --git a/ui/sd_internal/hypernetwork.py b/ui/sd_internal/hypernetwork.py
deleted file mode 100644
index 979a74f3..00000000
--- a/ui/sd_internal/hypernetwork.py
+++ /dev/null
@@ -1,198 +0,0 @@
-# this is basically a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py, mostly for feature parity
-# I, c0bra5, don't really understand how deep learning works. I just know how to port stuff.
-
-import inspect
-import torch
-import optimizedSD.splitAttention
-from . import runtime
-from einops import rearrange
-
-optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
-
-loaded_hypernetwork = None
-
-class HypernetworkModule(torch.nn.Module):
-    multiplier = 0.5
-    activation_dict = {
-        "linear": torch.nn.Identity,
-        "relu": torch.nn.ReLU,
-        "leakyrelu": torch.nn.LeakyReLU,
-        "elu": torch.nn.ELU,
-        "swish": torch.nn.Hardswish,
-        "tanh": torch.nn.Tanh,
-        "sigmoid": torch.nn.Sigmoid,
-    }
-    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
-
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
-                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
-        super().__init__()
-
-        assert layer_structure is not None, "layer_structure must not be None"
-        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
-        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
-
-        linears = []
-        for i in range(len(layer_structure) - 1):
-
-            # Add a fully-connected layer
-            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
-
-            # Add an activation func except last layer
-            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
-                pass
-            elif activation_func in self.activation_dict:
-                linears.append(self.activation_dict[activation_func]())
-            else:
-                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
-
-            # Add layer normalization
-            if add_layer_norm:
-                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
-
-            # Add dropout except last layer
-            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
-                linears.append(torch.nn.Dropout(p=0.3))
-
-        self.linear = torch.nn.Sequential(*linears)
-
-        self.fix_old_state_dict(state_dict)
-        self.load_state_dict(state_dict)
-
-        self.to(runtime.thread_data.device)
-
-    def fix_old_state_dict(self, state_dict):
-        changes = {
-            'linear1.bias': 'linear.0.bias',
-            'linear1.weight': 'linear.0.weight',
-            'linear2.bias': 'linear.1.bias',
-            'linear2.weight': 'linear.1.weight',
-        }
-
-        for fr, to in changes.items():
-            x = state_dict.get(fr, None)
-            if x is None:
-                continue
-
-            del state_dict[fr]
-            state_dict[to] = x
-
-    def forward(self, x: torch.Tensor):
-        return x + self.linear(x) * runtime.thread_data.hypernetwork_strength
-
-def apply_hypernetwork(hypernetwork, context, layer=None):
-    hypernetwork_layers = hypernetwork.get(context.shape[2], None)
-
-    if hypernetwork_layers is None:
-        return context, context
-
-    if layer is not None:
-        layer.hyper_k = hypernetwork_layers[0]
-        layer.hyper_v = hypernetwork_layers[1]
-
-    context_k = hypernetwork_layers[0](context)
-    context_v = hypernetwork_layers[1](context)
-    return context_k, context_v
-
-def get_kv(context, hypernetwork):
-    if hypernetwork is None:
-        return context, context
-    else:
-        return apply_hypernetwork(runtime.thread_data.hypernetwork, context)
-
-# This might need updating as the optimisedSD code changes
-# I think yall have a system for this (patch files in sd_internal) but idk how it works and no amount of searching gave me any clue
-# just in case for attribution https://github.com/easydiffusion/diffusion-kit/blob/e8ea0cadd543056059cd951e76d4744de76327d2/optimizedSD/splitAttention.py#L171
-def new_cross_attention_forward(self, x, context=None, mask=None):
-    h = self.heads
-
-    q = self.to_q(x)
-    # default context
-    context = context if context is not None else x() if inspect.isfunction(x) else x
-    # hypernetwork!
-    context_k, context_v = get_kv(context, runtime.thread_data.hypernetwork)
-    k = self.to_k(context_k)
-    v = self.to_v(context_v)
-    del context, x
-
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-
-    limit = k.shape[0]
-    att_step = self.att_step
-    q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
-    k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
-    v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
-
-    q_chunks.reverse()
-    k_chunks.reverse()
-    v_chunks.reverse()
-    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
-    del k, q, v
-    for i in range (0, limit, att_step):
-
-        q_buffer = q_chunks.pop()
-        k_buffer = k_chunks.pop()
-        v_buffer = v_chunks.pop()
-        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
-
-        del k_buffer, q_buffer
-        # attention, what we cannot get enough of, by chunks
-
-        sim_buffer = sim_buffer.softmax(dim=-1)
-
-        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
-        del v_buffer
-        sim[i:i+att_step,:,:] = sim_buffer
-
-        del sim_buffer
-    sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
-    return self.to_out(sim)
-
-
-def load_hypernetwork(path: str):
-
-    state_dict = torch.load(path, map_location='cpu')
-
-    layer_structure = state_dict.get('layer_structure', [1, 2, 1])
-    activation_func = state_dict.get('activation_func', None)
-    weight_init = state_dict.get('weight_initialization', 'Normal')
-    add_layer_norm = state_dict.get('is_layer_norm', False)
-    use_dropout = state_dict.get('use_dropout', False)
-    activate_output = state_dict.get('activate_output', True)
-    last_layer_dropout = state_dict.get('last_layer_dropout', False)
-    # this is a bit verbose so leaving it commented out for the poor soul who ever has to debug this
-    # print(f"layer_structure: {layer_structure}")
-    # print(f"activation_func: {activation_func}")
-    # print(f"weight_init: {weight_init}")
-    # print(f"add_layer_norm: {add_layer_norm}")
-    # print(f"use_dropout: {use_dropout}")
-    # print(f"activate_output: {activate_output}")
-    # print(f"last_layer_dropout: {last_layer_dropout}")
-
-    layers = {}
-    for size, sd in state_dict.items():
-        if type(size) == int:
-            layers[size] = (
-                HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm,
-                                   use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
-                HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm,
-                                   use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
-            )
-    print(f"hypernetwork loaded")
-    return layers
-
-
-
-# overriding of original function
-old_cross_attention_forward = optimizedSD.splitAttention.CrossAttention.forward
-# hijacks the cross attention forward function to add hyper network support
-def hijack_cross_attention():
-    print("hypernetwork functionality added to cross attention")
-    optimizedSD.splitAttention.CrossAttention.forward = new_cross_attention_forward
-# there was a cop on board
-def unhijack_cross_attention_forward():
-    print("hypernetwork functionality removed from cross attention")
-    optimizedSD.splitAttention.CrossAttention.forward = old_cross_attention_forward
-
-hijack_cross_attention()
\ No newline at end of file
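Note for reviewers: new_cross_attention_forward keeps optimizedSD's split attention. Instead of materialising the full (batch*heads, n, n) score matrix at once, it attends in slices of att_step along the batch*heads axis. Below is a simplified equivalent, using plain slicing instead of tensor_split/pop and assuming att_step divides the batch evenly; it is a sketch, not the actual splitAttention code:

import torch

def chunked_attention(q, k, v, att_step=1, scale=1.0):
    # accumulate the attention output slice by slice along the (batch*heads) axis
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    for i in range(0, q.shape[0], att_step):
        # scores for this slice only: (att_step, n, n) instead of (b*h, n, n)
        sim = torch.einsum('b i d, b j d -> b i j', q[i:i+att_step], k[i:i+att_step]) * scale
        out[i:i+att_step] = torch.einsum('b i j, b j d -> b i d', sim.softmax(dim=-1), v[i:i+att_step])
    return out

q = k = v = torch.randn(8, 77, 40)      # (batch*heads, tokens, head_dim)
full = torch.einsum('b i j, b j d -> b i d',
                    torch.einsum('b i d, b j d -> b i j', q, k).softmax(dim=-1), v)
assert torch.allclose(chunked_attention(q, k, v, att_step=2), full, atol=1e-5)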
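The hijack itself is plain monkey-patching: save a reference to the original CrossAttention.forward, assign a replacement on the class so every instance picks it up, and restore the saved reference to undo it. A toy demonstration with a stand-in class rather than optimizedSD.splitAttention.CrossAttention:

class CrossAttention:
    def forward(self, x, context=None):
        return "original"

old_forward = CrossAttention.forward          # keep a reference to restore later

def new_forward(self, x, context=None):
    # ...remap context through the hypernetwork here, then attend...
    return "hijacked"

def hijack():
    CrossAttention.forward = new_forward      # every instance now uses new_forward

def unhijack():
    CrossAttention.forward = old_forward      # restore the saved original

hijack()
assert CrossAttention().forward(None) == "hijacked"
unhijack()
assert CrossAttention().forward(None) == "original"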
diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py
index a8b249a2..21acb540 100644
--- a/ui/sd_internal/model_manager.py
+++ b/ui/sd_internal/model_manager.py
@@ -70,7 +70,6 @@ def resolve_model_to_use(model_name:str, model_type:str):
             print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
             return default_model_path + model_extension
 
-    print(f'No valid models found for model_name: {model_name}')
     return None
 
 def resolve_sd_model_to_use(model_name:str=None):
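The model_manager change drops the 'No valid models found' print on the fall-through return. With hypernetworks, a request that names no model is a normal state rather than an error, so returning None quietly and letting the caller decide is the friendlier contract. A hedged sketch of that contract; this is a hypothetical helper, not the actual resolve_model_to_use:

import os

def resolve_model_sketch(model_name, model_dir, default_model, ext='.ckpt'):
    # hypothetical simplification of resolve_model_to_use's fallback behaviour
    if model_name:
        candidate = os.path.join(model_dir, model_name + ext)
        if os.path.exists(candidate):
            return candidate
    default_path = os.path.join(model_dir, default_model + ext)
    if os.path.exists(default_path):
        print(f'Could not find the configured custom model {model_name}{ext}. Using the default one: {default_path}')
        return default_path
    return None  # quiet fall-through: "no model" is valid, e.g. no hypernetwork selected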
diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py
index 840cec09..90a868a5 100644
--- a/ui/sd_internal/runtime2.py
+++ b/ui/sd_internal/runtime2.py
@@ -41,6 +41,7 @@ def destroy():
     model_loader.unload_sd_model(thread_data)
     model_loader.unload_gfpgan_model(thread_data)
     model_loader.unload_realesrgan_model(thread_data)
+    model_loader.unload_hypernetwork_model(thread_data)
 
 def init_and_load_default_models():
     # init default model paths
@@ -60,9 +61,13 @@ def reload_models_if_necessary(req: Request):
         model_loader.load_sd_model(thread_data)
 
-    # if is_hypernetwork_reload_necessary(task.request):
-    #     current_state = ServerStates.LoadingModel
-    #     runtime.reload_hypernetwork()
+    if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model:
+        thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model
+
+        if thread_data.model_paths['hypernetwork'] is not None:
+            model_loader.load_hypernetwork_model(thread_data)
+        else:
+            model_loader.unload_hypernetwork_model(thread_data)
 
 def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
     images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
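The runtime2 logic above reloads the hypernetwork only when the requested path differs from what is already loaded, and treats None as an unload request. A runnable sketch of that decision with stand-ins for thread_data and model_loader; the class and function names here are hypothetical:

class ThreadData:                       # stand-in for thread_data
    def __init__(self):
        self.model_paths = {}

class Loader:                           # stand-in for model_loader
    def load_hypernetwork_model(self, td):
        print('loading', td.model_paths['hypernetwork'])
    def unload_hypernetwork_model(self, td):
        print('unloading hypernetwork')

def reload_hypernetwork_if_necessary(td, loader, requested_path):
    if td.model_paths.get('hypernetwork') == requested_path:
        return                          # unchanged: skip an expensive reload
    td.model_paths['hypernetwork'] = requested_path
    if requested_path is not None:
        loader.load_hypernetwork_model(td)
    else:
        loader.unload_hypernetwork_model(td)

td, loader = ThreadData(), Loader()
reload_hypernetwork_if_necessary(td, loader, 'hn.pt')   # -> loading hn.pt
reload_hypernetwork_if_necessary(td, loader, 'hn.pt')   # -> no-op, path unchanged
reload_hypernetwork_if_necessary(td, loader, None)      # -> unloading hypernetwork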