diff --git a/README.md b/README.md
index 2038c15b..01509a46 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,8 @@
[](https://discord.com/invite/u9yhsFmEkB) (for support, and development discussion) | [Troubleshooting guide for common problems](Troubleshooting.md)
+New! Experimental support for Stable Diffusion 2.0 is available in beta!
+
----
## Step 1: Download the installer
@@ -28,7 +30,9 @@ The installer will take care of whatever is needed. A friendly [Discord communit
- **No Dependencies or Technical Knowledge Required**: 1-click install for Windows 10/11 and Linux. *No dependencies*, no need for WSL or Docker or Conda or technical setup. Just download and run!
- **Clutter-free UI**: a friendly and simple UI, while providing a lot of powerful features
- Supports "*Text to Image*" and "*Image to Image*"
+- **Stable Diffusion 2.0 support (experimental)** - available in beta channel
- **Custom Models**: Use your own `.ckpt` file, by placing it inside the `models/stable-diffusion` folder!
+- **Auto scan for malicious models** - uses picklescan to prevent malicious models
- **Live Preview**: See the image as the AI is drawing it
- **Task Queue**: Queue up all your ideas, without waiting for the current task to finish
- **In-Painting**: Specify areas of your image to paint into
diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat
index 99fb74bc..2daa39e2 100644
--- a/scripts/on_sd_start.bat
+++ b/scripts/on_sd_start.bat
@@ -201,8 +201,10 @@ call WHERE uvicorn > .tmp
if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
if not exist "..\models\vae" mkdir "..\models\vae"
+if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
echo. > "..\models\vae\Put your VAE files here.txt"
+echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
@if exist "sd-v1-4.ckpt" (
for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (
diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh
index 177e4f73..f8f3d560 100755
--- a/scripts/on_sd_start.sh
+++ b/scripts/on_sd_start.sh
@@ -161,8 +161,10 @@ fi
mkdir -p "../models/stable-diffusion"
mkdir -p "../models/vae"
+mkdir -p "../models/hypernetwork"
echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
echo "" > "../models/vae/Put your VAE files here.txt"
+echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"
if [ -f "sd-v1-4.ckpt" ]; then
model_size=`find "sd-v1-4.ckpt" -printf "%s"`
diff --git a/ui/index.html b/ui/index.html
index 50869fc7..d3fb6da3 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -131,6 +131,12 @@
Click to learn more about VAEs
+
Sampler:
plms
diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js
index 2d1543cc..f503779a 100644
--- a/ui/media/js/auto-save.js
+++ b/ui/media/js/auto-save.js
@@ -14,12 +14,14 @@ const SETTINGS_IDS_LIST = [
"num_outputs_parallel",
"stable_diffusion_model",
"vae_model",
+ "hypernetwork_model",
"sampler",
"width",
"height",
"num_inference_steps",
"guidance_scale",
"prompt_strength",
+ "hypernetwork_strength",
"output_format",
"output_quality",
"negative_prompt",
diff --git a/ui/media/js/engine.js b/ui/media/js/engine.js
index 3efa095c..d3c20328 100644
--- a/ui/media/js/engine.js
+++ b/ui/media/js/engine.js
@@ -806,6 +806,11 @@
continue
}
if (key in TASK_OPTIONAL) {
+ if (typeof this._reqBody[key] == "undefined") {
+ delete this._reqBody[key]
+ console.warn(`reqBody[${key}] was set to undefined. Removing optional key without value...`)
+ continue
+ }
if (typeof this._reqBody[key] !== TASK_OPTIONAL[key]) {
throw new Error(`${key} need to be of type ${TASK_OPTIONAL[key]} but ${typeof this._reqBody[key]} was found.`)
}
diff --git a/ui/media/js/main.js b/ui/media/js/main.js
index aa0bbba3..de5c988b 100644
--- a/ui/media/js/main.js
+++ b/ui/media/js/main.js
@@ -35,6 +35,9 @@ let useUpscalingField = document.querySelector("#use_upscale")
let upscaleModelField = document.querySelector("#upscale_model")
let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
let vaeModelField = document.querySelector('#vae_model')
+let hypernetworkModelField = document.querySelector('#hypernetwork_model')
+let hypernetworkStrengthSlider = document.querySelector('#hypernetwork_strength_slider')
+let hypernetworkStrengthField = document.querySelector('#hypernetwork_strength')
let outputFormatField = document.querySelector('#output_format')
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
let updateBranchLabel = document.querySelector("#updateBranchLabel")
@@ -654,6 +657,7 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) {
}
}
+
function onTaskStart(task) {
if (!task.isProcessing || task.batchesDone >= task.batchCount) {
return
@@ -705,13 +709,13 @@ function onTaskStart(task) {
})
let instance = eventInfo.instance
if (!instance) {
- const factory = PLUGINS.OUTPUTS_FORMATS.get(newTaskReqBody.output_format)
+ const factory = PLUGINS.OUTPUTS_FORMATS.get(eventInfo.reqBody?.output_format || newTaskReqBody.output_format)
if (factory) {
- instance = factory(newTaskReqBody)
+ instance = factory(eventInfo.reqBody || newTaskReqBody)
}
if (!instance) {
- console.error(`${factory ? "Factory " + String(factory) : 'No factory defined'} for output format ${newTaskReqBody.output_format}. Instance is ${instance || 'undefined'}. Using default renderer.`)
- instance = new SD.RenderTask(newTaskReqBody)
+ console.error(`${factory ? "Factory " + String(factory) : 'No factory defined'} for output format ${eventInfo.reqBody?.output_format || newTaskReqBody.output_format}. Instance is ${instance || 'undefined'}. Using default renderer.`)
+ instance = new SD.RenderTask(eventInfo.reqBody || newTaskReqBody)
}
}
@@ -750,6 +754,10 @@ function createTask(task) {
if (task.reqBody.use_upscale) {
taskConfig += `, Upscale: ${task.reqBody.use_upscale}`
}
+ if (task.reqBody.use_hypernetwork_model) {
+ taskConfig += `, Hypernetwork: ${task.reqBody.use_hypernetwork_model}`
+ taskConfig += `, Hypernetwork Strength: ${task.reqBody.hypernetwork_strength}`
+ }
let taskEntry = document.createElement('div')
taskEntry.id = `imageTaskContainer-${Date.now()}`
@@ -1105,6 +1113,27 @@ promptStrengthSlider.addEventListener('input', updatePromptStrength)
promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
updatePromptStrength()
+/********************* Hypernetwork Strength **********************/
+function updateHypernetworkStrength() {
+ hypernetworkStrengthField.value = hypernetworkStrengthSlider.value / 100
+ hypernetworkStrengthField.dispatchEvent(new Event("change"))
+}
+
+function updateHypernetworkStrengthSlider() {
+ if (hypernetworkStrengthField.value < 0) {
+ hypernetworkStrengthField.value = 0
+ } else if (hypernetworkStrengthField.value > 0.99) {
+ hypernetworkStrengthField.value = 0.99
+ }
+
+ hypernetworkStrengthSlider.value = hypernetworkStrengthField.value * 100
+ hypernetworkStrengthSlider.dispatchEvent(new Event("change"))
+}
+
+hypernetworkStrengthSlider.addEventListener('input', updateHypernetworkStrength)
+hypernetworkStrengthField.addEventListener('input', updateHypernetworkStrengthSlider)
+updateHypernetworkStrength()
+
/********************* JPEG Quality **********************/
function updateOutputQuality() {
outputQualityField.value = 0 | outputQualitySlider.value
@@ -1138,12 +1167,14 @@ async function getModels() {
try {
const sd_model_setting_key = "stable_diffusion_model"
const vae_model_setting_key = "vae_model"
+ const hypernetwork_model_key = "hypernetwork_model"
const selectedSDModel = SETTINGS[sd_model_setting_key].value
const selectedVaeModel = SETTINGS[vae_model_setting_key].value
+ const selectedHypernetworkModel = SETTINGS[hypernetwork_model_key].value
const models = await SD.getModels()
const modelsOptions = models['options']
- if ( "scan-error" in models) {
+ if ("scan-error" in models) {
// let previewPane = document.getElementById('tab-content-wrapper')
let previewPane = document.getElementById('preview')
previewPane.style.background="red"
@@ -1154,7 +1185,10 @@ async function getModels() {
const stableDiffusionOptions = modelsOptions['stable-diffusion']
const vaeOptions = modelsOptions['vae']
+ const hypernetworkOptions = modelOptions['hypernetwork']
+
vaeOptions.unshift('') // add a None option
+ hypernetworkOptions.unshift('') // add a None option
function createModelOptions(modelField, selectedModel) {
return function(modelName) {
@@ -1172,6 +1206,7 @@ async function getModels() {
stableDiffusionOptions.forEach(createModelOptions(stableDiffusionModelField, selectedSDModel))
vaeOptions.forEach(createModelOptions(vaeModelField, selectedVaeModel))
+ hypernetworkOptions.forEach(createModelOptions(hypernetworkModelField, selectedHypernetworkModel))
// TODO: set default for model here too
SETTINGS[sd_model_setting_key].default = stableDiffusionOptions[0]
diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py
index 18b0a01b..a2abe294 100644
--- a/ui/sd_internal/__init__.py
+++ b/ui/sd_internal/__init__.py
@@ -23,6 +23,8 @@ class Request:
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
+ use_hypernetwork_model: str = None
+ hypernetwork_strength: float = 1
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
@@ -38,6 +40,7 @@ class Request:
"num_outputs": self.num_outputs,
"num_inference_steps": self.num_inference_steps,
"guidance_scale": self.guidance_scale,
+ "hypernetwork_strengtgh": self.guidance_scale,
"width": self.width,
"height": self.height,
"seed": self.seed,
@@ -47,6 +50,8 @@ class Request:
"use_upscale": self.use_upscale,
"use_stable_diffusion_model": self.use_stable_diffusion_model,
"use_vae_model": self.use_vae_model,
+ "use_hypernetwork_model": self.use_hypernetwork_model,
+ "hypernetwork_strength": self.hypernetwork_strength,
"output_format": self.output_format,
"output_quality": self.output_quality,
}
@@ -70,6 +75,8 @@ class Request:
use_upscale: {self.use_upscale}
use_stable_diffusion_model: {self.use_stable_diffusion_model}
use_vae_model: {self.use_vae_model}
+ use_hypernetwork_model: {self.use_hypernetwork_model}
+ hypernetwork_strength: {self.hypernetwork_strength}
show_only_filtered_image: {self.show_only_filtered_image}
output_format: {self.output_format}
output_quality: {self.output_quality}
diff --git a/ui/sd_internal/hypernetwork.py b/ui/sd_internal/hypernetwork.py
new file mode 100644
index 00000000..979a74f3
--- /dev/null
+++ b/ui/sd_internal/hypernetwork.py
@@ -0,0 +1,198 @@
+# this is basically a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py, mostly for feature parity
+# I, c0bra5, don't really understand how deep learning works. I just know how to port stuff.
+
+import inspect
+import torch
+import optimizedSD.splitAttention
+from . import runtime
+from einops import rearrange
+
+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
+loaded_hypernetwork = None
+
+class HypernetworkModule(torch.nn.Module):
+ multiplier = 0.5
+ activation_dict = {
+ "linear": torch.nn.Identity,
+ "relu": torch.nn.ReLU,
+ "leakyrelu": torch.nn.LeakyReLU,
+ "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish,
+ "tanh": torch.nn.Tanh,
+ "sigmoid": torch.nn.Sigmoid,
+ }
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+ add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
+ super().__init__()
+
+ assert layer_structure is not None, "layer_structure must not be None"
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
+
+ linears = []
+ for i in range(len(layer_structure) - 1):
+
+ # Add a fully-connected layer
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
+ # Add an activation func except last layer
+ if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
+ pass
+ elif activation_func in self.activation_dict:
+ linears.append(self.activation_dict[activation_func]())
+ else:
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
+ # Add layer normalization
+ if add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+
+ # Add dropout except last layer
+ if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
+ linears.append(torch.nn.Dropout(p=0.3))
+
+ self.linear = torch.nn.Sequential(*linears)
+
+ self.fix_old_state_dict(state_dict)
+ self.load_state_dict(state_dict)
+
+ self.to(runtime.thread_data.device)
+
+ def fix_old_state_dict(self, state_dict):
+ changes = {
+ 'linear1.bias': 'linear.0.bias',
+ 'linear1.weight': 'linear.0.weight',
+ 'linear2.bias': 'linear.1.bias',
+ 'linear2.weight': 'linear.1.weight',
+ }
+
+ for fr, to in changes.items():
+ x = state_dict.get(fr, None)
+ if x is None:
+ continue
+
+ del state_dict[fr]
+ state_dict[to] = x
+
+ def forward(self, x: torch.Tensor):
+ return x + self.linear(x) * runtime.thread_data.hypernetwork_strength
+
+def apply_hypernetwork(hypernetwork, context, layer=None):
+ hypernetwork_layers = hypernetwork.get(context.shape[2], None)
+
+ if hypernetwork_layers is None:
+ return context, context
+
+ if layer is not None:
+ layer.hyper_k = hypernetwork_layers[0]
+ layer.hyper_v = hypernetwork_layers[1]
+
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
+
+def get_kv(context, hypernetwork):
+ if hypernetwork is None:
+ return context, context
+ else:
+ return apply_hypernetwork(runtime.thread_data.hypernetwork, context)
+
+# This might need updating as the optimisedSD code changes
+# I think yall have a system for this (patch files in sd_internal) but idk how it works and no amount of searching gave me any clue
+# just in case for attribution https://github.com/easydiffusion/diffusion-kit/blob/e8ea0cadd543056059cd951e76d4744de76327d2/optimizedSD/splitAttention.py#L171
+def new_cross_attention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ # default context
+ context = context if context is not None else x() if inspect.isfunction(x) else x
+ # hypernetwork!
+ context_k, context_v = get_kv(context, runtime.thread_data.hypernetwork)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+
+ limit = k.shape[0]
+ att_step = self.att_step
+ q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
+ k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
+ v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
+
+ q_chunks.reverse()
+ k_chunks.reverse()
+ v_chunks.reverse()
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ del k, q, v
+ for i in range (0, limit, att_step):
+
+ q_buffer = q_chunks.pop()
+ k_buffer = k_chunks.pop()
+ v_buffer = v_chunks.pop()
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+ del k_buffer, q_buffer
+ # attention, what we cannot get enough of, by chunks
+
+ sim_buffer = sim_buffer.softmax(dim=-1)
+
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+ del v_buffer
+ sim[i:i+att_step,:,:] = sim_buffer
+
+ del sim_buffer
+ sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(sim)
+
+
+def load_hypernetwork(path: str):
+
+ state_dict = torch.load(path, map_location='cpu')
+
+ layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ activation_func = state_dict.get('activation_func', None)
+ weight_init = state_dict.get('weight_initialization', 'Normal')
+ add_layer_norm = state_dict.get('is_layer_norm', False)
+ use_dropout = state_dict.get('use_dropout', False)
+ activate_output = state_dict.get('activate_output', True)
+ last_layer_dropout = state_dict.get('last_layer_dropout', False)
+ # this is a bit verbose so leaving it commented out for the poor soul who ever has to debug this
+ # print(f"layer_structure: {layer_structure}")
+ # print(f"activation_func: {activation_func}")
+ # print(f"weight_init: {weight_init}")
+ # print(f"add_layer_norm: {add_layer_norm}")
+ # print(f"use_dropout: {use_dropout}")
+ # print(f"activate_output: {activate_output}")
+ # print(f"last_layer_dropout: {last_layer_dropout}")
+
+ layers = {}
+ for size, sd in state_dict.items():
+ if type(size) == int:
+ layers[size] = (
+ HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm,
+ use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
+ HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm,
+ use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
+ )
+ print(f"hypernetwork loaded")
+ return layers
+
+
+
+# overriding of original function
+old_cross_attention_forward = optimizedSD.splitAttention.CrossAttention.forward
+# hijacks the cross attention forward function to add hyper network support
+def hijack_cross_attention():
+ print("hypernetwork functionality added to cross attention")
+ optimizedSD.splitAttention.CrossAttention.forward = new_cross_attention_forward
+# there was a cop on board
+def unhijack_cross_attention_forward():
+ print("hypernetwork functionality removed from cross attention")
+ optimizedSD.splitAttention.CrossAttention.forward = old_cross_attention_forward
+
+hijack_cross_attention()
\ No newline at end of file
diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py
index 0f491865..2e0f8ce0 100644
--- a/ui/sd_internal/runtime.py
+++ b/ui/sd_internal/runtime.py
@@ -28,6 +28,9 @@ from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
+from . import hypernetwork
+from server import HYPERNETWORK_MODEL_EXTENSIONS# , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS
+
from threading import Lock
from safetensors.torch import load_file
@@ -57,12 +60,15 @@ def thread_init(device):
thread_data.ckpt_file = None
thread_data.vae_file = None
+ thread_data.hypernetwork_file = None
thread_data.gfpgan_file = None
thread_data.real_esrgan_file = None
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
+ thread_data.hypernetwork = None
+ thread_data.hypernetwork_strength = 1
thread_data.model_gfpgan = None
thread_data.model_real_esrgan = None
@@ -72,6 +78,8 @@ def thread_init(device):
thread_data.device_name = None
thread_data.unet_bs = 1
thread_data.precision = 'autocast'
+ thread_data.sampler_plms = None
+ thread_data.sampler_ddim = None
thread_data.turbo = False
thread_data.force_full_precision = False
@@ -433,6 +441,49 @@ def reload_model():
unload_filters()
load_model_ckpt()
+def is_hypernetwork_reload_necessary(req: Request):
+ needs_model_reload = False
+ if thread_data.hypernetwork_file != req.use_hypernetwork_model:
+ thread_data.hypernetwork_file = req.use_hypernetwork_model
+ needs_model_reload = True
+
+ return needs_model_reload
+
+def load_hypernetwork():
+ if thread_data.hypernetwork_file is not None:
+ try:
+ loaded = False
+ for model_extension in HYPERNETWORK_MODEL_EXTENSIONS:
+ if os.path.exists(thread_data.hypernetwork_file + model_extension):
+ print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}")
+ thread_data.hypernetwork = hypernetwork.load_hypernetwork(thread_data.hypernetwork_file + model_extension)
+ loaded = True
+ break
+
+ if not loaded:
+ print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}')
+ thread_data.hypernetwork_file = None
+ except:
+ print(traceback.format_exc())
+ print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}')
+ thread_data.hypernetwork_file = None
+
+def unload_hypernetwork():
+ if thread_data.hypernetwork is not None:
+ print('Unloading hypernetwork...')
+ if thread_data.device != 'cpu':
+ for i in thread_data.hypernetwork:
+ thread_data.hypernetwork[i][0].to('cpu')
+ thread_data.hypernetwork[i][1].to('cpu')
+ del thread_data.hypernetwork
+ thread_data.hypernetwork = None
+
+ gc()
+
+def reload_hypernetwork():
+ unload_hypernetwork()
+ load_hypernetwork()
+
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
return do_mk_img(req, data_queue, task_temp_images, step_callback)
@@ -509,6 +560,7 @@ def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, ste
res = Response()
res.request = req
res.images = []
+ thread_data.hypernetwork_strength = req.hypernetwork_strength
thread_data.temp_images.clear()
@@ -751,6 +803,8 @@ Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt}
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
VAE model: {req.use_vae_model}
+Hypernetwork Model: {req.use_hypernetwork_model}
+Hypernetwork Strength: {req.hypernetwork_strength}
'''
try:
with open(meta_out_path, 'w', encoding='utf-8') as f:
diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py
index d21f2877..41fc00f6 100644
--- a/ui/sd_internal/task_manager.py
+++ b/ui/sd_internal/task_manager.py
@@ -77,6 +77,8 @@ class ImageRequest(BaseModel):
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
+ use_hypernetwork_model: str = None
+ hypernetwork_strength: float = None
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
@@ -177,28 +179,35 @@ current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
current_vae_path = None
+current_hypernetwork_path = None
tasks_queue = []
task_cache = TaskCache()
default_model_to_load = None
default_vae_to_load = None
+default_hypernetwork_to_load = None
weak_thread_data = weakref.WeakKeyDictionary()
-def preload_model(ckpt_file_path=None, vae_file_path=None):
- global current_state, current_state_error, current_model_path, current_vae_path
+def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
+ global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
if ckpt_file_path == None:
ckpt_file_path = default_model_to_load
if vae_file_path == None:
vae_file_path = default_vae_to_load
+ if hypernetwork_file_path == None:
+ hypernetwork_file_path = default_hypernetwork_to_load
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
return
current_state = ServerStates.LoadingModel
try:
from . import runtime
+ runtime.thread_data.hypernetwork_file = hypernetwork_file_path
runtime.thread_data.ckpt_file = ckpt_file_path
runtime.thread_data.vae_file = vae_file_path
runtime.load_model_ckpt()
+ runtime.load_hypernetwork()
current_model_path = ckpt_file_path
current_vae_path = vae_file_path
+ current_hypernetwork_path = hypernetwork_file_path
current_state_error = None
current_state = ServerStates.Online
except Exception as e:
@@ -240,7 +249,7 @@ def thread_get_next_task():
manager_lock.release()
def thread_render(device):
- global current_state, current_state_error, current_model_path, current_vae_path
+ global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
from . import runtime
try:
runtime.thread_init(device)
@@ -285,6 +294,10 @@ def thread_render(device):
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
+ if runtime.is_hypernetwork_reload_necessary(task.request):
+ runtime.reload_hypernetwork()
+ current_hypernetwork_path = task.request.use_hypernetwork_model
+
if runtime.is_model_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_model()
@@ -504,6 +517,8 @@ def render(req : ImageRequest):
r.use_face_correction = req.use_face_correction
r.use_stable_diffusion_model = req.use_stable_diffusion_model
r.use_vae_model = req.use_vae_model
+ r.use_hypernetwork_model = req.use_hypernetwork_model
+ r.hypernetwork_strength = req.hypernetwork_strength
r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format
r.output_quality = req.output_quality
diff --git a/ui/server.py b/ui/server.py
index c7760889..804994a2 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -26,6 +26,7 @@ UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
+HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
@@ -193,6 +194,12 @@ def resolve_vae_to_use(model_name:str=None):
except:
return None
+def resolve_hypernetwork_to_use(model_name:str=None):
+ try:
+ return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
+ except:
+ return None
+
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
@@ -253,10 +260,12 @@ def getModels():
'active': {
'stable-diffusion': 'sd-v1-4',
'vae': '',
+ 'hypernetwork': '',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
'vae': [],
+ 'hypernetwork': [],
},
}
@@ -288,7 +297,7 @@ def getModels():
# custom models
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
-
+ listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
@@ -363,16 +372,19 @@ def ping(session_id:str=None):
response['devices'] = task_manager.get_devices()
return JSONResponse(response, headers=NOCACHE_HEADERS)
-def save_model_to_config(ckpt_model_name, vae_model_name):
+def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = ckpt_model_name
config['model']['vae'] = vae_model_name
+ config['model']['hypernetwork'] = hypernetwork_model_name
if vae_model_name is None or vae_model_name == "":
del config['model']['vae']
+ if hypernetwork_model_name is None or hypernetwork_model_name == "":
+ del config['model']['hypernetwork']
setConfig(config)
@@ -388,9 +400,10 @@ def update_render_devices_in_config(config, render_devices):
@app.post('/render')
def render(req : task_manager.ImageRequest):
try:
- save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model)
+ save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
req.use_vae_model = resolve_vae_to_use(req.use_vae_model)
+ req.use_hypernetwork_model = resolve_hypernetwork_to_use(req.use_hypernetwork_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
@@ -469,6 +482,7 @@ getModels()
# Start the task_manager
task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use()
+task_manager.default_hypernetwork_to_load = resolve_hypernetwork_to_use()
def update_render_threads():
config = getConfig()