mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-25 01:34:38 +01:00
Hypernetwork support (#619)
* Update README.md * Update README.md * Make on_sd_start.sh executable * Merge pull request #542 from patriceac/patch-1 Fix restoration of model and VAE * Merge pull request #541 from patriceac/patch-2 Fix restoration of parallel output setting * Hypernetwork support Adds support for hypernetworks. Hypernetworks are stored in /models/hypernetworks * forgot to remove unused code Co-authored-by: cmdr2 <secondary.cmdr2@gmail.com>
This commit is contained in:
parent
1283c6483d
commit
cbe91251ac
@ -3,6 +3,8 @@
|
||||
|
||||
[![Discord Server](https://img.shields.io/discord/1014774730907209781?label=Discord)](https://discord.com/invite/u9yhsFmEkB) (for support, and development discussion) | [Troubleshooting guide for common problems](Troubleshooting.md)
|
||||
|
||||
New! Experimental support for Stable Diffusion 2.0 is available in beta!
|
||||
|
||||
----
|
||||
|
||||
## Step 1: Download the installer
|
||||
@ -28,7 +30,9 @@ The installer will take care of whatever is needed. A friendly [Discord communit
|
||||
- **No Dependencies or Technical Knowledge Required**: 1-click install for Windows 10/11 and Linux. *No dependencies*, no need for WSL or Docker or Conda or technical setup. Just download and run!
|
||||
- **Clutter-free UI**: a friendly and simple UI, while providing a lot of powerful features
|
||||
- Supports "*Text to Image*" and "*Image to Image*"
|
||||
- **Stable Diffusion 2.0 support (experimental)** - available in beta channel
|
||||
- **Custom Models**: Use your own `.ckpt` file, by placing it inside the `models/stable-diffusion` folder!
|
||||
- **Auto scan for malicious models** - uses picklescan to prevent malicious models
|
||||
- **Live Preview**: See the image as the AI is drawing it
|
||||
- **Task Queue**: Queue up all your ideas, without waiting for the current task to finish
|
||||
- **In-Painting**: Specify areas of your image to paint into
|
||||
|
@ -201,8 +201,10 @@ call WHERE uvicorn > .tmp
|
||||
|
||||
if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
|
||||
if not exist "..\models\vae" mkdir "..\models\vae"
|
||||
if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
|
||||
echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
|
||||
echo. > "..\models\vae\Put your VAE files here.txt"
|
||||
echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
|
||||
|
||||
@if exist "sd-v1-4.ckpt" (
|
||||
for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (
|
||||
|
@ -161,8 +161,10 @@ fi
|
||||
|
||||
mkdir -p "../models/stable-diffusion"
|
||||
mkdir -p "../models/vae"
|
||||
mkdir -p "../models/hypernetwork"
|
||||
echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
|
||||
echo "" > "../models/vae/Put your VAE files here.txt"
|
||||
echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"
|
||||
|
||||
if [ -f "sd-v1-4.ckpt" ]; then
|
||||
model_size=`find "sd-v1-4.ckpt" -printf "%s"`
|
||||
|
@ -131,6 +131,12 @@
|
||||
</select>
|
||||
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip right">Click to learn more about VAEs</span></i></a>
|
||||
</td></tr>
|
||||
<tr class="pl-5"><td><label for="hypernetwork_model">Hypernetwork:</i></label></td><td>
|
||||
<select id="hypernetwork_model" name="hypernetwork_model">
|
||||
<!-- <option value="" selected>None</option> -->
|
||||
</select>
|
||||
</td></tr>
|
||||
<tr id="hypernetwork_strength_container" class="pl-5"><td><label for="hypernetwork_strength_slider">Hypernetwork Strength:</label></td><td> <input id="hypernetwork_strength_slider" name="hypernetwork_strength_slider" class="editor-slider" value="100" type="range" min="0" max="100"> <input id="hypernetwork_strength" name="hypernetwork_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)"><br/></td></tr></span>
|
||||
<tr id="samplerSelection" class="pl-5"><td><label for="sampler">Sampler:</label></td><td>
|
||||
<select id="sampler" name="sampler">
|
||||
<option value="plms">plms</option>
|
||||
|
@ -14,12 +14,14 @@ const SETTINGS_IDS_LIST = [
|
||||
"num_outputs_parallel",
|
||||
"stable_diffusion_model",
|
||||
"vae_model",
|
||||
"hypernetwork_model",
|
||||
"sampler",
|
||||
"width",
|
||||
"height",
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"prompt_strength",
|
||||
"hypernetwork_strength",
|
||||
"output_format",
|
||||
"output_quality",
|
||||
"negative_prompt",
|
||||
|
@ -35,6 +35,9 @@ let useUpscalingField = document.querySelector("#use_upscale")
|
||||
let upscaleModelField = document.querySelector("#upscale_model")
|
||||
let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
|
||||
let vaeModelField = document.querySelector('#vae_model')
|
||||
let hypernetworkModelField = document.querySelector('#hypernetwork_model')
|
||||
let hypernetworkStrengthSlider = document.querySelector('#hypernetwork_strength_slider')
|
||||
let hypernetworkStrengthField = document.querySelector('#hypernetwork_strength')
|
||||
let outputFormatField = document.querySelector('#output_format')
|
||||
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
|
||||
let updateBranchLabel = document.querySelector("#updateBranchLabel")
|
||||
@ -654,6 +657,7 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function onTaskStart(task) {
|
||||
if (!task.isProcessing || task.batchesDone >= task.batchCount) {
|
||||
return
|
||||
@ -750,6 +754,10 @@ function createTask(task) {
|
||||
if (task.reqBody.use_upscale) {
|
||||
taskConfig += `, <b>Upscale:</b> ${task.reqBody.use_upscale}`
|
||||
}
|
||||
if (task.reqBody.use_hypernetwork_model) {
|
||||
taskConfig += `, <b>Hypernetwork:</b> ${task.reqBody.use_hypernetwork_model}`
|
||||
taskConfig += `, <b>Hypernetwork Strength:</b> ${task.reqBody.hypernetwork_strength}`
|
||||
}
|
||||
|
||||
let taskEntry = document.createElement('div')
|
||||
taskEntry.id = `imageTaskContainer-${Date.now()}`
|
||||
@ -1105,6 +1113,27 @@ promptStrengthSlider.addEventListener('input', updatePromptStrength)
|
||||
promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
|
||||
updatePromptStrength()
|
||||
|
||||
/********************* Hypernetwork Strength **********************/
|
||||
function updateHypernetworkStrength() {
|
||||
hypernetworkStrengthField.value = hypernetworkStrengthSlider.value / 100
|
||||
hypernetworkStrengthField.dispatchEvent(new Event("change"))
|
||||
}
|
||||
|
||||
function updateHypernetworkStrengthSlider() {
|
||||
if (hypernetworkStrengthField.value < 0) {
|
||||
hypernetworkStrengthField.value = 0
|
||||
} else if (hypernetworkStrengthField.value > 0.99) {
|
||||
hypernetworkStrengthField.value = 0.99
|
||||
}
|
||||
|
||||
hypernetworkStrengthSlider.value = hypernetworkStrengthField.value * 100
|
||||
hypernetworkStrengthSlider.dispatchEvent(new Event("change"))
|
||||
}
|
||||
|
||||
hypernetworkStrengthSlider.addEventListener('input', updateHypernetworkStrength)
|
||||
hypernetworkStrengthField.addEventListener('input', updateHypernetworkStrengthSlider)
|
||||
updateHypernetworkStrength()
|
||||
|
||||
/********************* JPEG Quality **********************/
|
||||
function updateOutputQuality() {
|
||||
outputQualityField.value = 0 | outputQualitySlider.value
|
||||
@ -1138,8 +1167,10 @@ async function getModels() {
|
||||
try {
|
||||
const sd_model_setting_key = "stable_diffusion_model"
|
||||
const vae_model_setting_key = "vae_model"
|
||||
const hypernetwork_model_key = "hypernetwork_model"
|
||||
const selectedSDModel = SETTINGS[sd_model_setting_key].value
|
||||
const selectedVaeModel = SETTINGS[vae_model_setting_key].value
|
||||
const selectedHypernetworkModel = SETTINGS[hypernetwork_model_key].value
|
||||
|
||||
const models = await SD.getModels()
|
||||
const modelsOptions = models['options']
|
||||
@ -1154,7 +1185,10 @@ async function getModels() {
|
||||
|
||||
const stableDiffusionOptions = modelsOptions['stable-diffusion']
|
||||
const vaeOptions = modelsOptions['vae']
|
||||
const hypernetworkOptions = modelOptions['hypernetwork']
|
||||
|
||||
vaeOptions.unshift('') // add a None option
|
||||
hypernetworkOptions.unshift('') // add a None option
|
||||
|
||||
function createModelOptions(modelField, selectedModel) {
|
||||
return function(modelName) {
|
||||
@ -1172,6 +1206,7 @@ async function getModels() {
|
||||
|
||||
stableDiffusionOptions.forEach(createModelOptions(stableDiffusionModelField, selectedSDModel))
|
||||
vaeOptions.forEach(createModelOptions(vaeModelField, selectedVaeModel))
|
||||
hypernetworkOptions.forEach(createModelOptions(hypernetworkModelField, selectedHypernetworkModel))
|
||||
|
||||
// TODO: set default for model here too
|
||||
SETTINGS[sd_model_setting_key].default = stableDiffusionOptions[0]
|
||||
|
@ -23,6 +23,8 @@ class Request:
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
use_stable_diffusion_model: str = "sd-v1-4"
|
||||
use_vae_model: str = None
|
||||
use_hypernetwork_model: str = None
|
||||
hypernetwork_strength: float = 1
|
||||
show_only_filtered_image: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
output_quality: int = 75
|
||||
@ -38,6 +40,7 @@ class Request:
|
||||
"num_outputs": self.num_outputs,
|
||||
"num_inference_steps": self.num_inference_steps,
|
||||
"guidance_scale": self.guidance_scale,
|
||||
"hypernetwork_strengtgh": self.guidance_scale,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"seed": self.seed,
|
||||
@ -47,6 +50,8 @@ class Request:
|
||||
"use_upscale": self.use_upscale,
|
||||
"use_stable_diffusion_model": self.use_stable_diffusion_model,
|
||||
"use_vae_model": self.use_vae_model,
|
||||
"use_hypernetwork_model": self.use_hypernetwork_model,
|
||||
"hypernetwork_strength": self.hypernetwork_strength,
|
||||
"output_format": self.output_format,
|
||||
"output_quality": self.output_quality,
|
||||
}
|
||||
@ -70,6 +75,8 @@ class Request:
|
||||
use_upscale: {self.use_upscale}
|
||||
use_stable_diffusion_model: {self.use_stable_diffusion_model}
|
||||
use_vae_model: {self.use_vae_model}
|
||||
use_hypernetwork_model: {self.use_hypernetwork_model}
|
||||
hypernetwork_strength: {self.hypernetwork_strength}
|
||||
show_only_filtered_image: {self.show_only_filtered_image}
|
||||
output_format: {self.output_format}
|
||||
output_quality: {self.output_quality}
|
||||
|
198
ui/sd_internal/hypernetwork.py
Normal file
198
ui/sd_internal/hypernetwork.py
Normal file
@ -0,0 +1,198 @@
|
||||
# this is basically a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py, mostly for feature parity
|
||||
# I, c0bra5, don't really understand how deep learning works. I just know how to port stuff.
|
||||
|
||||
import inspect
|
||||
import torch
|
||||
import optimizedSD.splitAttention
|
||||
from . import runtime
|
||||
from einops import rearrange
|
||||
|
||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
loaded_hypernetwork = None
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 0.5
|
||||
activation_dict = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
"swish": torch.nn.Hardswish,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"sigmoid": torch.nn.Sigmoid,
|
||||
}
|
||||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
|
||||
super().__init__()
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
|
||||
# Add a fully-connected layer
|
||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add an activation func except last layer
|
||||
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
|
||||
pass
|
||||
elif activation_func in self.activation_dict:
|
||||
linears.append(self.activation_dict[activation_func]())
|
||||
else:
|
||||
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||
|
||||
# Add layer normalization
|
||||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add dropout except last layer
|
||||
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
|
||||
linears.append(torch.nn.Dropout(p=0.3))
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
|
||||
self.fix_old_state_dict(state_dict)
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
self.to(runtime.thread_data.device)
|
||||
|
||||
def fix_old_state_dict(self, state_dict):
|
||||
changes = {
|
||||
'linear1.bias': 'linear.0.bias',
|
||||
'linear1.weight': 'linear.0.weight',
|
||||
'linear2.bias': 'linear.1.bias',
|
||||
'linear2.weight': 'linear.1.weight',
|
||||
}
|
||||
|
||||
for fr, to in changes.items():
|
||||
x = state_dict.get(fr, None)
|
||||
if x is None:
|
||||
continue
|
||||
|
||||
del state_dict[fr]
|
||||
state_dict[to] = x
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x + self.linear(x) * runtime.thread_data.hypernetwork_strength
|
||||
|
||||
def apply_hypernetwork(hypernetwork, context, layer=None):
|
||||
hypernetwork_layers = hypernetwork.get(context.shape[2], None)
|
||||
|
||||
if hypernetwork_layers is None:
|
||||
return context, context
|
||||
|
||||
if layer is not None:
|
||||
layer.hyper_k = hypernetwork_layers[0]
|
||||
layer.hyper_v = hypernetwork_layers[1]
|
||||
|
||||
context_k = hypernetwork_layers[0](context)
|
||||
context_v = hypernetwork_layers[1](context)
|
||||
return context_k, context_v
|
||||
|
||||
def get_kv(context, hypernetwork):
|
||||
if hypernetwork is None:
|
||||
return context, context
|
||||
else:
|
||||
return apply_hypernetwork(runtime.thread_data.hypernetwork, context)
|
||||
|
||||
# This might need updating as the optimisedSD code changes
|
||||
# I think yall have a system for this (patch files in sd_internal) but idk how it works and no amount of searching gave me any clue
|
||||
# just in case for attribution https://github.com/easydiffusion/diffusion-kit/blob/e8ea0cadd543056059cd951e76d4744de76327d2/optimizedSD/splitAttention.py#L171
|
||||
def new_cross_attention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
# default context
|
||||
context = context if context is not None else x() if inspect.isfunction(x) else x
|
||||
# hypernetwork!
|
||||
context_k, context_v = get_kv(context, runtime.thread_data.hypernetwork)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
|
||||
limit = k.shape[0]
|
||||
att_step = self.att_step
|
||||
q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
|
||||
k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
|
||||
v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
|
||||
|
||||
q_chunks.reverse()
|
||||
k_chunks.reverse()
|
||||
v_chunks.reverse()
|
||||
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
del k, q, v
|
||||
for i in range (0, limit, att_step):
|
||||
|
||||
q_buffer = q_chunks.pop()
|
||||
k_buffer = k_chunks.pop()
|
||||
v_buffer = v_chunks.pop()
|
||||
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
|
||||
|
||||
del k_buffer, q_buffer
|
||||
# attention, what we cannot get enough of, by chunks
|
||||
|
||||
sim_buffer = sim_buffer.softmax(dim=-1)
|
||||
|
||||
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
|
||||
del v_buffer
|
||||
sim[i:i+att_step,:,:] = sim_buffer
|
||||
|
||||
del sim_buffer
|
||||
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(sim)
|
||||
|
||||
|
||||
def load_hypernetwork(path: str):
|
||||
|
||||
state_dict = torch.load(path, map_location='cpu')
|
||||
|
||||
layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
activation_func = state_dict.get('activation_func', None)
|
||||
weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||
add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
use_dropout = state_dict.get('use_dropout', False)
|
||||
activate_output = state_dict.get('activate_output', True)
|
||||
last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||
# this is a bit verbose so leaving it commented out for the poor soul who ever has to debug this
|
||||
# print(f"layer_structure: {layer_structure}")
|
||||
# print(f"activation_func: {activation_func}")
|
||||
# print(f"weight_init: {weight_init}")
|
||||
# print(f"add_layer_norm: {add_layer_norm}")
|
||||
# print(f"use_dropout: {use_dropout}")
|
||||
# print(f"activate_output: {activate_output}")
|
||||
# print(f"last_layer_dropout: {last_layer_dropout}")
|
||||
|
||||
layers = {}
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
layers[size] = (
|
||||
HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm,
|
||||
use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
|
||||
HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm,
|
||||
use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
|
||||
)
|
||||
print(f"hypernetwork loaded")
|
||||
return layers
|
||||
|
||||
|
||||
|
||||
# overriding of original function
|
||||
old_cross_attention_forward = optimizedSD.splitAttention.CrossAttention.forward
|
||||
# hijacks the cross attention forward function to add hyper network support
|
||||
def hijack_cross_attention():
|
||||
print("hypernetwork functionality added to cross attention")
|
||||
optimizedSD.splitAttention.CrossAttention.forward = new_cross_attention_forward
|
||||
# there was a cop on board
|
||||
def unhijack_cross_attention_forward():
|
||||
print("hypernetwork functionality removed from cross attention")
|
||||
optimizedSD.splitAttention.CrossAttention.forward = old_cross_attention_forward
|
||||
|
||||
hijack_cross_attention()
|
@ -28,6 +28,9 @@ from gfpgan import GFPGANer
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from . import hypernetwork
|
||||
from server import HYPERNETWORK_MODEL_EXTENSIONS# , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS
|
||||
|
||||
from threading import Lock
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@ -57,12 +60,15 @@ def thread_init(device):
|
||||
|
||||
thread_data.ckpt_file = None
|
||||
thread_data.vae_file = None
|
||||
thread_data.hypernetwork_file = None
|
||||
thread_data.gfpgan_file = None
|
||||
thread_data.real_esrgan_file = None
|
||||
|
||||
thread_data.model = None
|
||||
thread_data.modelCS = None
|
||||
thread_data.modelFS = None
|
||||
thread_data.hypernetwork = None
|
||||
thread_data.hypernetwork_strength = 1
|
||||
thread_data.model_gfpgan = None
|
||||
thread_data.model_real_esrgan = None
|
||||
|
||||
@ -72,6 +78,8 @@ def thread_init(device):
|
||||
thread_data.device_name = None
|
||||
thread_data.unet_bs = 1
|
||||
thread_data.precision = 'autocast'
|
||||
thread_data.sampler_plms = None
|
||||
thread_data.sampler_ddim = None
|
||||
|
||||
thread_data.turbo = False
|
||||
thread_data.force_full_precision = False
|
||||
@ -433,6 +441,49 @@ def reload_model():
|
||||
unload_filters()
|
||||
load_model_ckpt()
|
||||
|
||||
def is_hypernetwork_reload_necessary(req: Request):
|
||||
needs_model_reload = False
|
||||
if thread_data.hypernetwork_file != req.use_hypernetwork_model:
|
||||
thread_data.hypernetwork_file = req.use_hypernetwork_model
|
||||
needs_model_reload = True
|
||||
|
||||
return needs_model_reload
|
||||
|
||||
def load_hypernetwork():
|
||||
if thread_data.hypernetwork_file is not None:
|
||||
try:
|
||||
loaded = False
|
||||
for model_extension in HYPERNETWORK_MODEL_EXTENSIONS:
|
||||
if os.path.exists(thread_data.hypernetwork_file + model_extension):
|
||||
print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}")
|
||||
thread_data.hypernetwork = hypernetwork.load_hypernetwork(thread_data.hypernetwork_file + model_extension)
|
||||
loaded = True
|
||||
break
|
||||
|
||||
if not loaded:
|
||||
print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}')
|
||||
thread_data.hypernetwork_file = None
|
||||
except:
|
||||
print(traceback.format_exc())
|
||||
print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}')
|
||||
thread_data.hypernetwork_file = None
|
||||
|
||||
def unload_hypernetwork():
|
||||
if thread_data.hypernetwork is not None:
|
||||
print('Unloading hypernetwork...')
|
||||
if thread_data.device != 'cpu':
|
||||
for i in thread_data.hypernetwork:
|
||||
thread_data.hypernetwork[i][0].to('cpu')
|
||||
thread_data.hypernetwork[i][1].to('cpu')
|
||||
del thread_data.hypernetwork
|
||||
thread_data.hypernetwork = None
|
||||
|
||||
gc()
|
||||
|
||||
def reload_hypernetwork():
|
||||
unload_hypernetwork()
|
||||
load_hypernetwork()
|
||||
|
||||
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
try:
|
||||
return do_mk_img(req, data_queue, task_temp_images, step_callback)
|
||||
@ -509,6 +560,7 @@ def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, ste
|
||||
res = Response()
|
||||
res.request = req
|
||||
res.images = []
|
||||
thread_data.hypernetwork_strength = req.hypernetwork_strength
|
||||
|
||||
thread_data.temp_images.clear()
|
||||
|
||||
@ -751,6 +803,8 @@ Sampler: {req.sampler}
|
||||
Negative Prompt: {req.negative_prompt}
|
||||
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||
VAE model: {req.use_vae_model}
|
||||
Hypernetwork Model: {req.use_hypernetwork_model}
|
||||
Hypernetwork Strength: {req.hypernetwork_strength}
|
||||
'''
|
||||
try:
|
||||
with open(meta_out_path, 'w', encoding='utf-8') as f:
|
||||
|
@ -77,6 +77,8 @@ class ImageRequest(BaseModel):
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
use_stable_diffusion_model: str = "sd-v1-4"
|
||||
use_vae_model: str = None
|
||||
use_hypernetwork_model: str = None
|
||||
hypernetwork_strength: float = None
|
||||
show_only_filtered_image: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
output_quality: int = 75
|
||||
@ -177,28 +179,35 @@ current_state = ServerStates.Init
|
||||
current_state_error:Exception = None
|
||||
current_model_path = None
|
||||
current_vae_path = None
|
||||
current_hypernetwork_path = None
|
||||
tasks_queue = []
|
||||
task_cache = TaskCache()
|
||||
default_model_to_load = None
|
||||
default_vae_to_load = None
|
||||
default_hypernetwork_to_load = None
|
||||
weak_thread_data = weakref.WeakKeyDictionary()
|
||||
|
||||
def preload_model(ckpt_file_path=None, vae_file_path=None):
|
||||
global current_state, current_state_error, current_model_path, current_vae_path
|
||||
def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
|
||||
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
|
||||
if ckpt_file_path == None:
|
||||
ckpt_file_path = default_model_to_load
|
||||
if vae_file_path == None:
|
||||
vae_file_path = default_vae_to_load
|
||||
if hypernetwork_file_path == None:
|
||||
hypernetwork_file_path = default_hypernetwork_to_load
|
||||
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
|
||||
return
|
||||
current_state = ServerStates.LoadingModel
|
||||
try:
|
||||
from . import runtime
|
||||
runtime.thread_data.hypernetwork_file = hypernetwork_file_path
|
||||
runtime.thread_data.ckpt_file = ckpt_file_path
|
||||
runtime.thread_data.vae_file = vae_file_path
|
||||
runtime.load_model_ckpt()
|
||||
runtime.load_hypernetwork()
|
||||
current_model_path = ckpt_file_path
|
||||
current_vae_path = vae_file_path
|
||||
current_hypernetwork_path = hypernetwork_file_path
|
||||
current_state_error = None
|
||||
current_state = ServerStates.Online
|
||||
except Exception as e:
|
||||
@ -240,7 +249,7 @@ def thread_get_next_task():
|
||||
manager_lock.release()
|
||||
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error, current_model_path, current_vae_path
|
||||
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
|
||||
from . import runtime
|
||||
try:
|
||||
runtime.thread_init(device)
|
||||
@ -285,6 +294,10 @@ def thread_render(device):
|
||||
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
|
||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||
try:
|
||||
if runtime.is_hypernetwork_reload_necessary(task.request):
|
||||
runtime.reload_hypernetwork()
|
||||
current_hypernetwork_path = task.request.use_hypernetwork_model
|
||||
|
||||
if runtime.is_model_reload_necessary(task.request):
|
||||
current_state = ServerStates.LoadingModel
|
||||
runtime.reload_model()
|
||||
@ -504,6 +517,8 @@ def render(req : ImageRequest):
|
||||
r.use_face_correction = req.use_face_correction
|
||||
r.use_stable_diffusion_model = req.use_stable_diffusion_model
|
||||
r.use_vae_model = req.use_vae_model
|
||||
r.use_hypernetwork_model = req.use_hypernetwork_model
|
||||
r.hypernetwork_strength = req.hypernetwork_strength
|
||||
r.show_only_filtered_image = req.show_only_filtered_image
|
||||
r.output_format = req.output_format
|
||||
r.output_quality = req.output_quality
|
||||
|
20
ui/server.py
20
ui/server.py
@ -26,6 +26,7 @@ UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user
|
||||
|
||||
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
|
||||
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
|
||||
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
|
||||
|
||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||
@ -193,6 +194,12 @@ def resolve_vae_to_use(model_name:str=None):
|
||||
except:
|
||||
return None
|
||||
|
||||
def resolve_hypernetwork_to_use(model_name:str=None):
|
||||
try:
|
||||
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
|
||||
except:
|
||||
return None
|
||||
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = None
|
||||
render_devices: Union[List[str], List[int], str, int] = None
|
||||
@ -253,10 +260,12 @@ def getModels():
|
||||
'active': {
|
||||
'stable-diffusion': 'sd-v1-4',
|
||||
'vae': '',
|
||||
'hypernetwork': '',
|
||||
},
|
||||
'options': {
|
||||
'stable-diffusion': ['sd-v1-4'],
|
||||
'vae': [],
|
||||
'hypernetwork': [],
|
||||
},
|
||||
}
|
||||
|
||||
@ -288,7 +297,7 @@ def getModels():
|
||||
# custom models
|
||||
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
|
||||
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
|
||||
|
||||
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
|
||||
# legacy
|
||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||
if os.path.exists(custom_weight_path):
|
||||
@ -363,16 +372,19 @@ def ping(session_id:str=None):
|
||||
response['devices'] = task_manager.get_devices()
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
def save_model_to_config(ckpt_model_name, vae_model_name):
|
||||
def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
|
||||
config = getConfig()
|
||||
if 'model' not in config:
|
||||
config['model'] = {}
|
||||
|
||||
config['model']['stable-diffusion'] = ckpt_model_name
|
||||
config['model']['vae'] = vae_model_name
|
||||
config['model']['hypernetwork'] = hypernetwork_model_name
|
||||
|
||||
if vae_model_name is None or vae_model_name == "":
|
||||
del config['model']['vae']
|
||||
if hypernetwork_model_name is None or hypernetwork_model_name == "":
|
||||
del config['model']['hypernetwork']
|
||||
|
||||
setConfig(config)
|
||||
|
||||
@ -388,9 +400,10 @@ def update_render_devices_in_config(config, render_devices):
|
||||
@app.post('/render')
|
||||
def render(req : task_manager.ImageRequest):
|
||||
try:
|
||||
save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model)
|
||||
save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
|
||||
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
|
||||
req.use_vae_model = resolve_vae_to_use(req.use_vae_model)
|
||||
req.use_hypernetwork_model = resolve_hypernetwork_to_use(req.use_hypernetwork_model)
|
||||
new_task = task_manager.render(req)
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
@ -469,6 +482,7 @@ getModels()
|
||||
# Start the task_manager
|
||||
task_manager.default_model_to_load = resolve_ckpt_to_use()
|
||||
task_manager.default_vae_to_load = resolve_vae_to_use()
|
||||
task_manager.default_hypernetwork_to_load = resolve_hypernetwork_to_use()
|
||||
|
||||
def update_render_threads():
|
||||
config = getConfig()
|
||||
|
Loading…
Reference in New Issue
Block a user