Hypernetwork support (#619)

* Update README.md

* Update README.md

* Make on_sd_start.sh executable

* Merge pull request #542 from patriceac/patch-1

Fix restoration of model and VAE

* Merge pull request #541 from patriceac/patch-2

Fix restoration of parallel output setting

* Hypernetwork support

Adds support for hypernetworks. Hypernetworks are stored in /models/hypernetworks

* forgot to remove unused code

Co-authored-by: cmdr2 <secondary.cmdr2@gmail.com>
This commit is contained in:
Guillaume Mercier 2022-12-07 00:54:16 -05:00 committed by GitHub
parent 1283c6483d
commit cbe91251ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 346 additions and 7 deletions

View File

@ -3,6 +3,8 @@
[![Discord Server](https://img.shields.io/discord/1014774730907209781?label=Discord)](https://discord.com/invite/u9yhsFmEkB) (for support, and development discussion) | [Troubleshooting guide for common problems](Troubleshooting.md)
New! Experimental support for Stable Diffusion 2.0 is available in beta!
----
## Step 1: Download the installer
@ -28,7 +30,9 @@ The installer will take care of whatever is needed. A friendly [Discord communit
- **No Dependencies or Technical Knowledge Required**: 1-click install for Windows 10/11 and Linux. *No dependencies*, no need for WSL or Docker or Conda or technical setup. Just download and run!
- **Clutter-free UI**: a friendly and simple UI, while providing a lot of powerful features
- Supports "*Text to Image*" and "*Image to Image*"
- **Stable Diffusion 2.0 support (experimental)** - available in beta channel
- **Custom Models**: Use your own `.ckpt` file, by placing it inside the `models/stable-diffusion` folder!
- **Auto scan for malicious models** - uses picklescan to prevent malicious models
- **Live Preview**: See the image as the AI is drawing it
- **Task Queue**: Queue up all your ideas, without waiting for the current task to finish
- **In-Painting**: Specify areas of your image to paint into

View File

@ -201,8 +201,10 @@ call WHERE uvicorn > .tmp
if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
if not exist "..\models\vae" mkdir "..\models\vae"
if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
echo. > "..\models\vae\Put your VAE files here.txt"
echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
@if exist "sd-v1-4.ckpt" (
for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (

View File

@ -161,8 +161,10 @@ fi
mkdir -p "../models/stable-diffusion"
mkdir -p "../models/vae"
mkdir -p "../models/hypernetwork"
echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
echo "" > "../models/vae/Put your VAE files here.txt"
echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"
if [ -f "sd-v1-4.ckpt" ]; then
model_size=`find "sd-v1-4.ckpt" -printf "%s"`

View File

@ -131,6 +131,12 @@
</select>
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip right">Click to learn more about VAEs</span></i></a>
</td></tr>
<tr class="pl-5"><td><label for="hypernetwork_model">Hypernetwork:</i></label></td><td>
<select id="hypernetwork_model" name="hypernetwork_model">
<!-- <option value="" selected>None</option> -->
</select>
</td></tr>
<tr id="hypernetwork_strength_container" class="pl-5"><td><label for="hypernetwork_strength_slider">Hypernetwork Strength:</label></td><td> <input id="hypernetwork_strength_slider" name="hypernetwork_strength_slider" class="editor-slider" value="100" type="range" min="0" max="100"> <input id="hypernetwork_strength" name="hypernetwork_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)"><br/></td></tr></span>
<tr id="samplerSelection" class="pl-5"><td><label for="sampler">Sampler:</label></td><td>
<select id="sampler" name="sampler">
<option value="plms">plms</option>

View File

@ -14,12 +14,14 @@ const SETTINGS_IDS_LIST = [
"num_outputs_parallel",
"stable_diffusion_model",
"vae_model",
"hypernetwork_model",
"sampler",
"width",
"height",
"num_inference_steps",
"guidance_scale",
"prompt_strength",
"hypernetwork_strength",
"output_format",
"output_quality",
"negative_prompt",

View File

@ -35,6 +35,9 @@ let useUpscalingField = document.querySelector("#use_upscale")
let upscaleModelField = document.querySelector("#upscale_model")
let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
let vaeModelField = document.querySelector('#vae_model')
let hypernetworkModelField = document.querySelector('#hypernetwork_model')
let hypernetworkStrengthSlider = document.querySelector('#hypernetwork_strength_slider')
let hypernetworkStrengthField = document.querySelector('#hypernetwork_strength')
let outputFormatField = document.querySelector('#output_format')
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
let updateBranchLabel = document.querySelector("#updateBranchLabel")
@ -654,6 +657,7 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) {
}
}
function onTaskStart(task) {
if (!task.isProcessing || task.batchesDone >= task.batchCount) {
return
@ -750,6 +754,10 @@ function createTask(task) {
if (task.reqBody.use_upscale) {
taskConfig += `, <b>Upscale:</b> ${task.reqBody.use_upscale}`
}
if (task.reqBody.use_hypernetwork_model) {
taskConfig += `, <b>Hypernetwork:</b> ${task.reqBody.use_hypernetwork_model}`
taskConfig += `, <b>Hypernetwork Strength:</b> ${task.reqBody.hypernetwork_strength}`
}
let taskEntry = document.createElement('div')
taskEntry.id = `imageTaskContainer-${Date.now()}`
@ -1105,6 +1113,27 @@ promptStrengthSlider.addEventListener('input', updatePromptStrength)
promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
updatePromptStrength()
/********************* Hypernetwork Strength **********************/
function updateHypernetworkStrength() {
hypernetworkStrengthField.value = hypernetworkStrengthSlider.value / 100
hypernetworkStrengthField.dispatchEvent(new Event("change"))
}
function updateHypernetworkStrengthSlider() {
if (hypernetworkStrengthField.value < 0) {
hypernetworkStrengthField.value = 0
} else if (hypernetworkStrengthField.value > 0.99) {
hypernetworkStrengthField.value = 0.99
}
hypernetworkStrengthSlider.value = hypernetworkStrengthField.value * 100
hypernetworkStrengthSlider.dispatchEvent(new Event("change"))
}
hypernetworkStrengthSlider.addEventListener('input', updateHypernetworkStrength)
hypernetworkStrengthField.addEventListener('input', updateHypernetworkStrengthSlider)
updateHypernetworkStrength()
/********************* JPEG Quality **********************/
function updateOutputQuality() {
outputQualityField.value = 0 | outputQualitySlider.value
@ -1138,8 +1167,10 @@ async function getModels() {
try {
const sd_model_setting_key = "stable_diffusion_model"
const vae_model_setting_key = "vae_model"
const hypernetwork_model_key = "hypernetwork_model"
const selectedSDModel = SETTINGS[sd_model_setting_key].value
const selectedVaeModel = SETTINGS[vae_model_setting_key].value
const selectedHypernetworkModel = SETTINGS[hypernetwork_model_key].value
const models = await SD.getModels()
const modelsOptions = models['options']
@ -1154,7 +1185,10 @@ async function getModels() {
const stableDiffusionOptions = modelsOptions['stable-diffusion']
const vaeOptions = modelsOptions['vae']
const hypernetworkOptions = modelOptions['hypernetwork']
vaeOptions.unshift('') // add a None option
hypernetworkOptions.unshift('') // add a None option
function createModelOptions(modelField, selectedModel) {
return function(modelName) {
@ -1172,6 +1206,7 @@ async function getModels() {
stableDiffusionOptions.forEach(createModelOptions(stableDiffusionModelField, selectedSDModel))
vaeOptions.forEach(createModelOptions(vaeModelField, selectedVaeModel))
hypernetworkOptions.forEach(createModelOptions(hypernetworkModelField, selectedHypernetworkModel))
// TODO: set default for model here too
SETTINGS[sd_model_setting_key].default = stableDiffusionOptions[0]

View File

@ -23,6 +23,8 @@ class Request:
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
use_hypernetwork_model: str = None
hypernetwork_strength: float = 1
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
@ -38,6 +40,7 @@ class Request:
"num_outputs": self.num_outputs,
"num_inference_steps": self.num_inference_steps,
"guidance_scale": self.guidance_scale,
"hypernetwork_strengtgh": self.guidance_scale,
"width": self.width,
"height": self.height,
"seed": self.seed,
@ -47,6 +50,8 @@ class Request:
"use_upscale": self.use_upscale,
"use_stable_diffusion_model": self.use_stable_diffusion_model,
"use_vae_model": self.use_vae_model,
"use_hypernetwork_model": self.use_hypernetwork_model,
"hypernetwork_strength": self.hypernetwork_strength,
"output_format": self.output_format,
"output_quality": self.output_quality,
}
@ -70,6 +75,8 @@ class Request:
use_upscale: {self.use_upscale}
use_stable_diffusion_model: {self.use_stable_diffusion_model}
use_vae_model: {self.use_vae_model}
use_hypernetwork_model: {self.use_hypernetwork_model}
hypernetwork_strength: {self.hypernetwork_strength}
show_only_filtered_image: {self.show_only_filtered_image}
output_format: {self.output_format}
output_quality: {self.output_quality}

View File

@ -0,0 +1,198 @@
# this is basically a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py, mostly for feature parity
# I, c0bra5, don't really understand how deep learning works. I just know how to port stuff.
import inspect
import torch
import optimizedSD.splitAttention
from . import runtime
from einops import rearrange
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
loaded_hypernetwork = None
class HypernetworkModule(torch.nn.Module):
multiplier = 0.5
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# Add an activation func except last layer
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
# Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout except last layer
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
self.fix_old_state_dict(state_dict)
self.load_state_dict(state_dict)
self.to(runtime.thread_data.device)
def fix_old_state_dict(self, state_dict):
changes = {
'linear1.bias': 'linear.0.bias',
'linear1.weight': 'linear.0.weight',
'linear2.bias': 'linear.1.bias',
'linear2.weight': 'linear.1.weight',
}
for fr, to in changes.items():
x = state_dict.get(fr, None)
if x is None:
continue
del state_dict[fr]
state_dict[to] = x
def forward(self, x: torch.Tensor):
return x + self.linear(x) * runtime.thread_data.hypernetwork_strength
def apply_hypernetwork(hypernetwork, context, layer=None):
hypernetwork_layers = hypernetwork.get(context.shape[2], None)
if hypernetwork_layers is None:
return context, context
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
context_k = hypernetwork_layers[0](context)
context_v = hypernetwork_layers[1](context)
return context_k, context_v
def get_kv(context, hypernetwork):
if hypernetwork is None:
return context, context
else:
return apply_hypernetwork(runtime.thread_data.hypernetwork, context)
# This might need updating as the optimisedSD code changes
# I think yall have a system for this (patch files in sd_internal) but idk how it works and no amount of searching gave me any clue
# just in case for attribution https://github.com/easydiffusion/diffusion-kit/blob/e8ea0cadd543056059cd951e76d4744de76327d2/optimizedSD/splitAttention.py#L171
def new_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
# default context
context = context if context is not None else x() if inspect.isfunction(x) else x
# hypernetwork!
context_k, context_v = get_kv(context, runtime.thread_data.hypernetwork)
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
limit = k.shape[0]
att_step = self.att_step
q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
q_chunks.reverse()
k_chunks.reverse()
v_chunks.reverse()
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
del k, q, v
for i in range (0, limit, att_step):
q_buffer = q_chunks.pop()
k_buffer = k_chunks.pop()
v_buffer = v_chunks.pop()
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
del k_buffer, q_buffer
# attention, what we cannot get enough of, by chunks
sim_buffer = sim_buffer.softmax(dim=-1)
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
del v_buffer
sim[i:i+att_step,:,:] = sim_buffer
del sim_buffer
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
return self.to_out(sim)
def load_hypernetwork(path: str):
state_dict = torch.load(path, map_location='cpu')
layer_structure = state_dict.get('layer_structure', [1, 2, 1])
activation_func = state_dict.get('activation_func', None)
weight_init = state_dict.get('weight_initialization', 'Normal')
add_layer_norm = state_dict.get('is_layer_norm', False)
use_dropout = state_dict.get('use_dropout', False)
activate_output = state_dict.get('activate_output', True)
last_layer_dropout = state_dict.get('last_layer_dropout', False)
# this is a bit verbose so leaving it commented out for the poor soul who ever has to debug this
# print(f"layer_structure: {layer_structure}")
# print(f"activation_func: {activation_func}")
# print(f"weight_init: {weight_init}")
# print(f"add_layer_norm: {add_layer_norm}")
# print(f"use_dropout: {use_dropout}")
# print(f"activate_output: {activate_output}")
# print(f"last_layer_dropout: {last_layer_dropout}")
layers = {}
for size, sd in state_dict.items():
if type(size) == int:
layers[size] = (
HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm,
use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm,
use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
)
print(f"hypernetwork loaded")
return layers
# overriding of original function
old_cross_attention_forward = optimizedSD.splitAttention.CrossAttention.forward
# hijacks the cross attention forward function to add hyper network support
def hijack_cross_attention():
print("hypernetwork functionality added to cross attention")
optimizedSD.splitAttention.CrossAttention.forward = new_cross_attention_forward
# there was a cop on board
def unhijack_cross_attention_forward():
print("hypernetwork functionality removed from cross attention")
optimizedSD.splitAttention.CrossAttention.forward = old_cross_attention_forward
hijack_cross_attention()

View File

@ -28,6 +28,9 @@ from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from . import hypernetwork
from server import HYPERNETWORK_MODEL_EXTENSIONS# , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS
from threading import Lock
from safetensors.torch import load_file
@ -57,12 +60,15 @@ def thread_init(device):
thread_data.ckpt_file = None
thread_data.vae_file = None
thread_data.hypernetwork_file = None
thread_data.gfpgan_file = None
thread_data.real_esrgan_file = None
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
thread_data.hypernetwork = None
thread_data.hypernetwork_strength = 1
thread_data.model_gfpgan = None
thread_data.model_real_esrgan = None
@ -72,6 +78,8 @@ def thread_init(device):
thread_data.device_name = None
thread_data.unet_bs = 1
thread_data.precision = 'autocast'
thread_data.sampler_plms = None
thread_data.sampler_ddim = None
thread_data.turbo = False
thread_data.force_full_precision = False
@ -433,6 +441,49 @@ def reload_model():
unload_filters()
load_model_ckpt()
def is_hypernetwork_reload_necessary(req: Request):
needs_model_reload = False
if thread_data.hypernetwork_file != req.use_hypernetwork_model:
thread_data.hypernetwork_file = req.use_hypernetwork_model
needs_model_reload = True
return needs_model_reload
def load_hypernetwork():
if thread_data.hypernetwork_file is not None:
try:
loaded = False
for model_extension in HYPERNETWORK_MODEL_EXTENSIONS:
if os.path.exists(thread_data.hypernetwork_file + model_extension):
print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}")
thread_data.hypernetwork = hypernetwork.load_hypernetwork(thread_data.hypernetwork_file + model_extension)
loaded = True
break
if not loaded:
print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}')
thread_data.hypernetwork_file = None
except:
print(traceback.format_exc())
print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}')
thread_data.hypernetwork_file = None
def unload_hypernetwork():
if thread_data.hypernetwork is not None:
print('Unloading hypernetwork...')
if thread_data.device != 'cpu':
for i in thread_data.hypernetwork:
thread_data.hypernetwork[i][0].to('cpu')
thread_data.hypernetwork[i][1].to('cpu')
del thread_data.hypernetwork
thread_data.hypernetwork = None
gc()
def reload_hypernetwork():
unload_hypernetwork()
load_hypernetwork()
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
return do_mk_img(req, data_queue, task_temp_images, step_callback)
@ -509,6 +560,7 @@ def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, ste
res = Response()
res.request = req
res.images = []
thread_data.hypernetwork_strength = req.hypernetwork_strength
thread_data.temp_images.clear()
@ -751,6 +803,8 @@ Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt}
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
VAE model: {req.use_vae_model}
Hypernetwork Model: {req.use_hypernetwork_model}
Hypernetwork Strength: {req.hypernetwork_strength}
'''
try:
with open(meta_out_path, 'w', encoding='utf-8') as f:

View File

@ -77,6 +77,8 @@ class ImageRequest(BaseModel):
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
use_hypernetwork_model: str = None
hypernetwork_strength: float = None
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
@ -177,28 +179,35 @@ current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
current_vae_path = None
current_hypernetwork_path = None
tasks_queue = []
task_cache = TaskCache()
default_model_to_load = None
default_vae_to_load = None
default_hypernetwork_to_load = None
weak_thread_data = weakref.WeakKeyDictionary()
def preload_model(ckpt_file_path=None, vae_file_path=None):
global current_state, current_state_error, current_model_path, current_vae_path
def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
if ckpt_file_path == None:
ckpt_file_path = default_model_to_load
if vae_file_path == None:
vae_file_path = default_vae_to_load
if hypernetwork_file_path == None:
hypernetwork_file_path = default_hypernetwork_to_load
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
return
current_state = ServerStates.LoadingModel
try:
from . import runtime
runtime.thread_data.hypernetwork_file = hypernetwork_file_path
runtime.thread_data.ckpt_file = ckpt_file_path
runtime.thread_data.vae_file = vae_file_path
runtime.load_model_ckpt()
runtime.load_hypernetwork()
current_model_path = ckpt_file_path
current_vae_path = vae_file_path
current_hypernetwork_path = hypernetwork_file_path
current_state_error = None
current_state = ServerStates.Online
except Exception as e:
@ -240,7 +249,7 @@ def thread_get_next_task():
manager_lock.release()
def thread_render(device):
global current_state, current_state_error, current_model_path, current_vae_path
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
from . import runtime
try:
runtime.thread_init(device)
@ -285,6 +294,10 @@ def thread_render(device):
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
if runtime.is_hypernetwork_reload_necessary(task.request):
runtime.reload_hypernetwork()
current_hypernetwork_path = task.request.use_hypernetwork_model
if runtime.is_model_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_model()
@ -504,6 +517,8 @@ def render(req : ImageRequest):
r.use_face_correction = req.use_face_correction
r.use_stable_diffusion_model = req.use_stable_diffusion_model
r.use_vae_model = req.use_vae_model
r.use_hypernetwork_model = req.use_hypernetwork_model
r.hypernetwork_strength = req.hypernetwork_strength
r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format
r.output_quality = req.output_quality

View File

@ -26,6 +26,7 @@ UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
@ -193,6 +194,12 @@ def resolve_vae_to_use(model_name:str=None):
except:
return None
def resolve_hypernetwork_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
except:
return None
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
@ -253,10 +260,12 @@ def getModels():
'active': {
'stable-diffusion': 'sd-v1-4',
'vae': '',
'hypernetwork': '',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
'vae': [],
'hypernetwork': [],
},
}
@ -288,7 +297,7 @@ def getModels():
# custom models
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
@ -363,16 +372,19 @@ def ping(session_id:str=None):
response['devices'] = task_manager.get_devices()
return JSONResponse(response, headers=NOCACHE_HEADERS)
def save_model_to_config(ckpt_model_name, vae_model_name):
def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = ckpt_model_name
config['model']['vae'] = vae_model_name
config['model']['hypernetwork'] = hypernetwork_model_name
if vae_model_name is None or vae_model_name == "":
del config['model']['vae']
if hypernetwork_model_name is None or hypernetwork_model_name == "":
del config['model']['hypernetwork']
setConfig(config)
@ -388,9 +400,10 @@ def update_render_devices_in_config(config, render_devices):
@app.post('/render')
def render(req : task_manager.ImageRequest):
try:
save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model)
save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
req.use_vae_model = resolve_vae_to_use(req.use_vae_model)
req.use_hypernetwork_model = resolve_hypernetwork_to_use(req.use_hypernetwork_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
@ -469,6 +482,7 @@ getModels()
# Start the task_manager
task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use()
task_manager.default_hypernetwork_to_load = resolve_hypernetwork_to_use()
def update_render_threads():
config = getConfig()