mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-08-19 06:52:04 +02:00
Hypernetwork support (#619)
* Update README.md * Update README.md * Make on_sd_start.sh executable * Merge pull request #542 from patriceac/patch-1 Fix restoration of model and VAE * Merge pull request #541 from patriceac/patch-2 Fix restoration of parallel output setting * Hypernetwork support Adds support for hypernetworks. Hypernetworks are stored in /models/hypernetworks * forgot to remove unused code Co-authored-by: cmdr2 <secondary.cmdr2@gmail.com>
This commit is contained in:
committed by
GitHub
parent
1283c6483d
commit
cbe91251ac
@@ -23,6 +23,8 @@ class Request:
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
use_stable_diffusion_model: str = "sd-v1-4"
|
||||
use_vae_model: str = None
|
||||
use_hypernetwork_model: str = None
|
||||
hypernetwork_strength: float = 1
|
||||
show_only_filtered_image: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
output_quality: int = 75
|
||||
@@ -38,6 +40,7 @@ class Request:
|
||||
"num_outputs": self.num_outputs,
|
||||
"num_inference_steps": self.num_inference_steps,
|
||||
"guidance_scale": self.guidance_scale,
|
||||
"hypernetwork_strengtgh": self.guidance_scale,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"seed": self.seed,
|
||||
@@ -47,6 +50,8 @@ class Request:
|
||||
"use_upscale": self.use_upscale,
|
||||
"use_stable_diffusion_model": self.use_stable_diffusion_model,
|
||||
"use_vae_model": self.use_vae_model,
|
||||
"use_hypernetwork_model": self.use_hypernetwork_model,
|
||||
"hypernetwork_strength": self.hypernetwork_strength,
|
||||
"output_format": self.output_format,
|
||||
"output_quality": self.output_quality,
|
||||
}
|
||||
@@ -70,6 +75,8 @@ class Request:
|
||||
use_upscale: {self.use_upscale}
|
||||
use_stable_diffusion_model: {self.use_stable_diffusion_model}
|
||||
use_vae_model: {self.use_vae_model}
|
||||
use_hypernetwork_model: {self.use_hypernetwork_model}
|
||||
hypernetwork_strength: {self.hypernetwork_strength}
|
||||
show_only_filtered_image: {self.show_only_filtered_image}
|
||||
output_format: {self.output_format}
|
||||
output_quality: {self.output_quality}
|
||||
|
198
ui/sd_internal/hypernetwork.py
Normal file
198
ui/sd_internal/hypernetwork.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# this is basically a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py, mostly for feature parity
|
||||
# I, c0bra5, don't really understand how deep learning works. I just know how to port stuff.
|
||||
|
||||
import inspect
|
||||
import torch
|
||||
import optimizedSD.splitAttention
|
||||
from . import runtime
|
||||
from einops import rearrange
|
||||
|
||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
loaded_hypernetwork = None
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 0.5
|
||||
activation_dict = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
"swish": torch.nn.Hardswish,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"sigmoid": torch.nn.Sigmoid,
|
||||
}
|
||||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
|
||||
super().__init__()
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
|
||||
# Add a fully-connected layer
|
||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add an activation func except last layer
|
||||
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
|
||||
pass
|
||||
elif activation_func in self.activation_dict:
|
||||
linears.append(self.activation_dict[activation_func]())
|
||||
else:
|
||||
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||
|
||||
# Add layer normalization
|
||||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add dropout except last layer
|
||||
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
|
||||
linears.append(torch.nn.Dropout(p=0.3))
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
|
||||
self.fix_old_state_dict(state_dict)
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
self.to(runtime.thread_data.device)
|
||||
|
||||
def fix_old_state_dict(self, state_dict):
|
||||
changes = {
|
||||
'linear1.bias': 'linear.0.bias',
|
||||
'linear1.weight': 'linear.0.weight',
|
||||
'linear2.bias': 'linear.1.bias',
|
||||
'linear2.weight': 'linear.1.weight',
|
||||
}
|
||||
|
||||
for fr, to in changes.items():
|
||||
x = state_dict.get(fr, None)
|
||||
if x is None:
|
||||
continue
|
||||
|
||||
del state_dict[fr]
|
||||
state_dict[to] = x
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x + self.linear(x) * runtime.thread_data.hypernetwork_strength
|
||||
|
||||
def apply_hypernetwork(hypernetwork, context, layer=None):
|
||||
hypernetwork_layers = hypernetwork.get(context.shape[2], None)
|
||||
|
||||
if hypernetwork_layers is None:
|
||||
return context, context
|
||||
|
||||
if layer is not None:
|
||||
layer.hyper_k = hypernetwork_layers[0]
|
||||
layer.hyper_v = hypernetwork_layers[1]
|
||||
|
||||
context_k = hypernetwork_layers[0](context)
|
||||
context_v = hypernetwork_layers[1](context)
|
||||
return context_k, context_v
|
||||
|
||||
def get_kv(context, hypernetwork):
|
||||
if hypernetwork is None:
|
||||
return context, context
|
||||
else:
|
||||
return apply_hypernetwork(runtime.thread_data.hypernetwork, context)
|
||||
|
||||
# This might need updating as the optimisedSD code changes
|
||||
# I think yall have a system for this (patch files in sd_internal) but idk how it works and no amount of searching gave me any clue
|
||||
# just in case for attribution https://github.com/easydiffusion/diffusion-kit/blob/e8ea0cadd543056059cd951e76d4744de76327d2/optimizedSD/splitAttention.py#L171
|
||||
def new_cross_attention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
# default context
|
||||
context = context if context is not None else x() if inspect.isfunction(x) else x
|
||||
# hypernetwork!
|
||||
context_k, context_v = get_kv(context, runtime.thread_data.hypernetwork)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
|
||||
limit = k.shape[0]
|
||||
att_step = self.att_step
|
||||
q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
|
||||
k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
|
||||
v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
|
||||
|
||||
q_chunks.reverse()
|
||||
k_chunks.reverse()
|
||||
v_chunks.reverse()
|
||||
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
del k, q, v
|
||||
for i in range (0, limit, att_step):
|
||||
|
||||
q_buffer = q_chunks.pop()
|
||||
k_buffer = k_chunks.pop()
|
||||
v_buffer = v_chunks.pop()
|
||||
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
|
||||
|
||||
del k_buffer, q_buffer
|
||||
# attention, what we cannot get enough of, by chunks
|
||||
|
||||
sim_buffer = sim_buffer.softmax(dim=-1)
|
||||
|
||||
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
|
||||
del v_buffer
|
||||
sim[i:i+att_step,:,:] = sim_buffer
|
||||
|
||||
del sim_buffer
|
||||
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(sim)
|
||||
|
||||
|
||||
def load_hypernetwork(path: str):
|
||||
|
||||
state_dict = torch.load(path, map_location='cpu')
|
||||
|
||||
layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
activation_func = state_dict.get('activation_func', None)
|
||||
weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||
add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
use_dropout = state_dict.get('use_dropout', False)
|
||||
activate_output = state_dict.get('activate_output', True)
|
||||
last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||
# this is a bit verbose so leaving it commented out for the poor soul who ever has to debug this
|
||||
# print(f"layer_structure: {layer_structure}")
|
||||
# print(f"activation_func: {activation_func}")
|
||||
# print(f"weight_init: {weight_init}")
|
||||
# print(f"add_layer_norm: {add_layer_norm}")
|
||||
# print(f"use_dropout: {use_dropout}")
|
||||
# print(f"activate_output: {activate_output}")
|
||||
# print(f"last_layer_dropout: {last_layer_dropout}")
|
||||
|
||||
layers = {}
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
layers[size] = (
|
||||
HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm,
|
||||
use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
|
||||
HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm,
|
||||
use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
|
||||
)
|
||||
print(f"hypernetwork loaded")
|
||||
return layers
|
||||
|
||||
|
||||
|
||||
# overriding of original function
|
||||
old_cross_attention_forward = optimizedSD.splitAttention.CrossAttention.forward
|
||||
# hijacks the cross attention forward function to add hyper network support
|
||||
def hijack_cross_attention():
|
||||
print("hypernetwork functionality added to cross attention")
|
||||
optimizedSD.splitAttention.CrossAttention.forward = new_cross_attention_forward
|
||||
# there was a cop on board
|
||||
def unhijack_cross_attention_forward():
|
||||
print("hypernetwork functionality removed from cross attention")
|
||||
optimizedSD.splitAttention.CrossAttention.forward = old_cross_attention_forward
|
||||
|
||||
hijack_cross_attention()
|
@@ -28,6 +28,9 @@ from gfpgan import GFPGANer
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from . import hypernetwork
|
||||
from server import HYPERNETWORK_MODEL_EXTENSIONS# , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS
|
||||
|
||||
from threading import Lock
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -57,12 +60,15 @@ def thread_init(device):
|
||||
|
||||
thread_data.ckpt_file = None
|
||||
thread_data.vae_file = None
|
||||
thread_data.hypernetwork_file = None
|
||||
thread_data.gfpgan_file = None
|
||||
thread_data.real_esrgan_file = None
|
||||
|
||||
thread_data.model = None
|
||||
thread_data.modelCS = None
|
||||
thread_data.modelFS = None
|
||||
thread_data.hypernetwork = None
|
||||
thread_data.hypernetwork_strength = 1
|
||||
thread_data.model_gfpgan = None
|
||||
thread_data.model_real_esrgan = None
|
||||
|
||||
@@ -72,6 +78,8 @@ def thread_init(device):
|
||||
thread_data.device_name = None
|
||||
thread_data.unet_bs = 1
|
||||
thread_data.precision = 'autocast'
|
||||
thread_data.sampler_plms = None
|
||||
thread_data.sampler_ddim = None
|
||||
|
||||
thread_data.turbo = False
|
||||
thread_data.force_full_precision = False
|
||||
@@ -433,6 +441,49 @@ def reload_model():
|
||||
unload_filters()
|
||||
load_model_ckpt()
|
||||
|
||||
def is_hypernetwork_reload_necessary(req: Request):
|
||||
needs_model_reload = False
|
||||
if thread_data.hypernetwork_file != req.use_hypernetwork_model:
|
||||
thread_data.hypernetwork_file = req.use_hypernetwork_model
|
||||
needs_model_reload = True
|
||||
|
||||
return needs_model_reload
|
||||
|
||||
def load_hypernetwork():
|
||||
if thread_data.hypernetwork_file is not None:
|
||||
try:
|
||||
loaded = False
|
||||
for model_extension in HYPERNETWORK_MODEL_EXTENSIONS:
|
||||
if os.path.exists(thread_data.hypernetwork_file + model_extension):
|
||||
print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}")
|
||||
thread_data.hypernetwork = hypernetwork.load_hypernetwork(thread_data.hypernetwork_file + model_extension)
|
||||
loaded = True
|
||||
break
|
||||
|
||||
if not loaded:
|
||||
print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}')
|
||||
thread_data.hypernetwork_file = None
|
||||
except:
|
||||
print(traceback.format_exc())
|
||||
print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}')
|
||||
thread_data.hypernetwork_file = None
|
||||
|
||||
def unload_hypernetwork():
|
||||
if thread_data.hypernetwork is not None:
|
||||
print('Unloading hypernetwork...')
|
||||
if thread_data.device != 'cpu':
|
||||
for i in thread_data.hypernetwork:
|
||||
thread_data.hypernetwork[i][0].to('cpu')
|
||||
thread_data.hypernetwork[i][1].to('cpu')
|
||||
del thread_data.hypernetwork
|
||||
thread_data.hypernetwork = None
|
||||
|
||||
gc()
|
||||
|
||||
def reload_hypernetwork():
|
||||
unload_hypernetwork()
|
||||
load_hypernetwork()
|
||||
|
||||
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
try:
|
||||
return do_mk_img(req, data_queue, task_temp_images, step_callback)
|
||||
@@ -509,6 +560,7 @@ def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, ste
|
||||
res = Response()
|
||||
res.request = req
|
||||
res.images = []
|
||||
thread_data.hypernetwork_strength = req.hypernetwork_strength
|
||||
|
||||
thread_data.temp_images.clear()
|
||||
|
||||
@@ -751,6 +803,8 @@ Sampler: {req.sampler}
|
||||
Negative Prompt: {req.negative_prompt}
|
||||
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||
VAE model: {req.use_vae_model}
|
||||
Hypernetwork Model: {req.use_hypernetwork_model}
|
||||
Hypernetwork Strength: {req.hypernetwork_strength}
|
||||
'''
|
||||
try:
|
||||
with open(meta_out_path, 'w', encoding='utf-8') as f:
|
||||
|
@@ -77,6 +77,8 @@ class ImageRequest(BaseModel):
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
use_stable_diffusion_model: str = "sd-v1-4"
|
||||
use_vae_model: str = None
|
||||
use_hypernetwork_model: str = None
|
||||
hypernetwork_strength: float = None
|
||||
show_only_filtered_image: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
output_quality: int = 75
|
||||
@@ -177,28 +179,35 @@ current_state = ServerStates.Init
|
||||
current_state_error:Exception = None
|
||||
current_model_path = None
|
||||
current_vae_path = None
|
||||
current_hypernetwork_path = None
|
||||
tasks_queue = []
|
||||
task_cache = TaskCache()
|
||||
default_model_to_load = None
|
||||
default_vae_to_load = None
|
||||
default_hypernetwork_to_load = None
|
||||
weak_thread_data = weakref.WeakKeyDictionary()
|
||||
|
||||
def preload_model(ckpt_file_path=None, vae_file_path=None):
|
||||
global current_state, current_state_error, current_model_path, current_vae_path
|
||||
def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
|
||||
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
|
||||
if ckpt_file_path == None:
|
||||
ckpt_file_path = default_model_to_load
|
||||
if vae_file_path == None:
|
||||
vae_file_path = default_vae_to_load
|
||||
if hypernetwork_file_path == None:
|
||||
hypernetwork_file_path = default_hypernetwork_to_load
|
||||
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
|
||||
return
|
||||
current_state = ServerStates.LoadingModel
|
||||
try:
|
||||
from . import runtime
|
||||
runtime.thread_data.hypernetwork_file = hypernetwork_file_path
|
||||
runtime.thread_data.ckpt_file = ckpt_file_path
|
||||
runtime.thread_data.vae_file = vae_file_path
|
||||
runtime.load_model_ckpt()
|
||||
runtime.load_hypernetwork()
|
||||
current_model_path = ckpt_file_path
|
||||
current_vae_path = vae_file_path
|
||||
current_hypernetwork_path = hypernetwork_file_path
|
||||
current_state_error = None
|
||||
current_state = ServerStates.Online
|
||||
except Exception as e:
|
||||
@@ -240,7 +249,7 @@ def thread_get_next_task():
|
||||
manager_lock.release()
|
||||
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error, current_model_path, current_vae_path
|
||||
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
|
||||
from . import runtime
|
||||
try:
|
||||
runtime.thread_init(device)
|
||||
@@ -285,6 +294,10 @@ def thread_render(device):
|
||||
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
|
||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||
try:
|
||||
if runtime.is_hypernetwork_reload_necessary(task.request):
|
||||
runtime.reload_hypernetwork()
|
||||
current_hypernetwork_path = task.request.use_hypernetwork_model
|
||||
|
||||
if runtime.is_model_reload_necessary(task.request):
|
||||
current_state = ServerStates.LoadingModel
|
||||
runtime.reload_model()
|
||||
@@ -504,6 +517,8 @@ def render(req : ImageRequest):
|
||||
r.use_face_correction = req.use_face_correction
|
||||
r.use_stable_diffusion_model = req.use_stable_diffusion_model
|
||||
r.use_vae_model = req.use_vae_model
|
||||
r.use_hypernetwork_model = req.use_hypernetwork_model
|
||||
r.hypernetwork_strength = req.hypernetwork_strength
|
||||
r.show_only_filtered_image = req.show_only_filtered_image
|
||||
r.output_format = req.output_format
|
||||
r.output_quality = req.output_quality
|
||||
|
Reference in New Issue
Block a user