Hypernetwork support (#619)

* Update README.md
* Update README.md
* Make on_sd_start.sh executable
* Merge pull request #542 from patriceac/patch-1: Fix restoration of model and VAE
* Merge pull request #541 from patriceac/patch-2: Fix restoration of parallel output setting
* Hypernetwork support: adds support for hypernetworks, which are stored in /models/hypernetwork
* forgot to remove unused code

Co-authored-by: cmdr2 <secondary.cmdr2@gmail.com>

parent 1283c6483d
commit cbe91251ac
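For reference, a minimal sketch (not part of the commit) of how a client could exercise the new request fields once this change is running. The host/port (localhost:9000) and the model name "my-hypernetwork" are assumptions for illustration; the field names "use_hypernetwork_model" and "hypernetwork_strength" come from the ImageRequest changes below, and the model name is resolved against models/hypernetwork/*.pt by the new resolve_hypernetwork_to_use().

# Sketch only, not part of the commit. Assumes the UI server is reachable at
# http://localhost:9000 and that models/hypernetwork/my-hypernetwork.pt exists.
import requests

payload = {
    "prompt": "a photograph of an astronaut riding a horse",
    "width": 512,
    "height": 512,
    # new fields introduced by this commit:
    "use_hypernetwork_model": "my-hypernetwork",  # resolved against models/hypernetwork/*.pt
    "hypernetwork_strength": 0.5,                 # mirrors the new 0-100 UI slider, scaled to 0..1
}

resp = requests.post("http://localhost:9000/render", json=payload)
print(resp.status_code, resp.json().get("status"))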
@@ -3,6 +3,8 @@
 
 [![Discord Server](https://img.shields.io/discord/1014774730907209781?label=Discord)](https://discord.com/invite/u9yhsFmEkB) (for support, and development discussion) | [Troubleshooting guide for common problems](Troubleshooting.md)
 
+New! Experimental support for Stable Diffusion 2.0 is available in beta!
+
 ----
 
 ## Step 1: Download the installer

@@ -28,7 +30,9 @@ The installer will take care of whatever is needed. A friendly [Discord communit
 - **No Dependencies or Technical Knowledge Required**: 1-click install for Windows 10/11 and Linux. *No dependencies*, no need for WSL or Docker or Conda or technical setup. Just download and run!
 - **Clutter-free UI**: a friendly and simple UI, while providing a lot of powerful features
 - Supports "*Text to Image*" and "*Image to Image*"
+- **Stable Diffusion 2.0 support (experimental)** - available in beta channel
 - **Custom Models**: Use your own `.ckpt` file, by placing it inside the `models/stable-diffusion` folder!
+- **Auto scan for malicious models** - uses picklescan to prevent malicious models
 - **Live Preview**: See the image as the AI is drawing it
 - **Task Queue**: Queue up all your ideas, without waiting for the current task to finish
 - **In-Painting**: Specify areas of your image to paint into

@@ -201,8 +201,10 @@ call WHERE uvicorn > .tmp
 
 if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
 if not exist "..\models\vae" mkdir "..\models\vae"
+if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
 echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
 echo. > "..\models\vae\Put your VAE files here.txt"
+echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
 
 @if exist "sd-v1-4.ckpt" (
     for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (

@@ -161,8 +161,10 @@ fi
 
 mkdir -p "../models/stable-diffusion"
 mkdir -p "../models/vae"
+mkdir -p "../models/hypernetwork"
 echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
 echo "" > "../models/vae/Put your VAE files here.txt"
+echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"
 
 if [ -f "sd-v1-4.ckpt" ]; then
     model_size=`find "sd-v1-4.ckpt" -printf "%s"`

@@ -131,6 +131,12 @@
 </select>
 <a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip right">Click to learn more about VAEs</span></i></a>
 </td></tr>
+<tr class="pl-5"><td><label for="hypernetwork_model">Hypernetwork:</label></td><td>
+<select id="hypernetwork_model" name="hypernetwork_model">
+<!-- <option value="" selected>None</option> -->
+</select>
+</td></tr>
+<tr id="hypernetwork_strength_container" class="pl-5"><td><label for="hypernetwork_strength_slider">Hypernetwork Strength:</label></td><td> <input id="hypernetwork_strength_slider" name="hypernetwork_strength_slider" class="editor-slider" value="100" type="range" min="0" max="100"> <input id="hypernetwork_strength" name="hypernetwork_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)"><br/></td></tr>
 <tr id="samplerSelection" class="pl-5"><td><label for="sampler">Sampler:</label></td><td>
 <select id="sampler" name="sampler">
 <option value="plms">plms</option>

@@ -14,12 +14,14 @@ const SETTINGS_IDS_LIST = [
     "num_outputs_parallel",
     "stable_diffusion_model",
     "vae_model",
+    "hypernetwork_model",
     "sampler",
     "width",
     "height",
     "num_inference_steps",
     "guidance_scale",
     "prompt_strength",
+    "hypernetwork_strength",
     "output_format",
     "output_quality",
     "negative_prompt",

@@ -35,6 +35,9 @@ let useUpscalingField = document.querySelector("#use_upscale")
 let upscaleModelField = document.querySelector("#upscale_model")
 let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
 let vaeModelField = document.querySelector('#vae_model')
+let hypernetworkModelField = document.querySelector('#hypernetwork_model')
+let hypernetworkStrengthSlider = document.querySelector('#hypernetwork_strength_slider')
+let hypernetworkStrengthField = document.querySelector('#hypernetwork_strength')
 let outputFormatField = document.querySelector('#output_format')
 let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
 let updateBranchLabel = document.querySelector("#updateBranchLabel")

@@ -654,6 +657,7 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) {
     }
 }
 
+
 function onTaskStart(task) {
     if (!task.isProcessing || task.batchesDone >= task.batchCount) {
         return

@@ -750,6 +754,10 @@ function createTask(task) {
     if (task.reqBody.use_upscale) {
         taskConfig += `, <b>Upscale:</b> ${task.reqBody.use_upscale}`
     }
+    if (task.reqBody.use_hypernetwork_model) {
+        taskConfig += `, <b>Hypernetwork:</b> ${task.reqBody.use_hypernetwork_model}`
+        taskConfig += `, <b>Hypernetwork Strength:</b> ${task.reqBody.hypernetwork_strength}`
+    }
 
     let taskEntry = document.createElement('div')
     taskEntry.id = `imageTaskContainer-${Date.now()}`

@@ -1105,6 +1113,27 @@ promptStrengthSlider.addEventListener('input', updatePromptStrength)
 promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
 updatePromptStrength()
 
+/********************* Hypernetwork Strength **********************/
+function updateHypernetworkStrength() {
+    hypernetworkStrengthField.value = hypernetworkStrengthSlider.value / 100
+    hypernetworkStrengthField.dispatchEvent(new Event("change"))
+}
+
+function updateHypernetworkStrengthSlider() {
+    if (hypernetworkStrengthField.value < 0) {
+        hypernetworkStrengthField.value = 0
+    } else if (hypernetworkStrengthField.value > 0.99) {
+        hypernetworkStrengthField.value = 0.99
+    }
+
+    hypernetworkStrengthSlider.value = hypernetworkStrengthField.value * 100
+    hypernetworkStrengthSlider.dispatchEvent(new Event("change"))
+}
+
+hypernetworkStrengthSlider.addEventListener('input', updateHypernetworkStrength)
+hypernetworkStrengthField.addEventListener('input', updateHypernetworkStrengthSlider)
+updateHypernetworkStrength()
+
 /********************* JPEG Quality **********************/
 function updateOutputQuality() {
     outputQualityField.value = 0 | outputQualitySlider.value

@@ -1138,12 +1167,14 @@ async function getModels() {
     try {
         const sd_model_setting_key = "stable_diffusion_model"
         const vae_model_setting_key = "vae_model"
+        const hypernetwork_model_key = "hypernetwork_model"
         const selectedSDModel = SETTINGS[sd_model_setting_key].value
         const selectedVaeModel = SETTINGS[vae_model_setting_key].value
+        const selectedHypernetworkModel = SETTINGS[hypernetwork_model_key].value
 
         const models = await SD.getModels()
         const modelsOptions = models['options']
-        if ( "scan-error" in models) {
+        if ("scan-error" in models) {
             // let previewPane = document.getElementById('tab-content-wrapper')
             let previewPane = document.getElementById('preview')
             previewPane.style.background="red"

@@ -1154,7 +1185,10 @@ async function getModels() {
 
         const stableDiffusionOptions = modelsOptions['stable-diffusion']
         const vaeOptions = modelsOptions['vae']
+        const hypernetworkOptions = modelsOptions['hypernetwork']
+
         vaeOptions.unshift('') // add a None option
+        hypernetworkOptions.unshift('') // add a None option
 
         function createModelOptions(modelField, selectedModel) {
             return function(modelName) {

@@ -1172,6 +1206,7 @@ async function getModels() {
 
         stableDiffusionOptions.forEach(createModelOptions(stableDiffusionModelField, selectedSDModel))
         vaeOptions.forEach(createModelOptions(vaeModelField, selectedVaeModel))
+        hypernetworkOptions.forEach(createModelOptions(hypernetworkModelField, selectedHypernetworkModel))
 
         // TODO: set default for model here too
         SETTINGS[sd_model_setting_key].default = stableDiffusionOptions[0]

@@ -23,6 +23,8 @@ class Request:
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
     use_stable_diffusion_model: str = "sd-v1-4"
     use_vae_model: str = None
+    use_hypernetwork_model: str = None
+    hypernetwork_strength: float = 1
     show_only_filtered_image: bool = False
     output_format: str = "jpeg" # or "png"
     output_quality: int = 75

@@ -38,6 +40,7 @@ class Request:
             "num_outputs": self.num_outputs,
             "num_inference_steps": self.num_inference_steps,
             "guidance_scale": self.guidance_scale,
+            "hypernetwork_strength": self.hypernetwork_strength,
             "width": self.width,
             "height": self.height,
             "seed": self.seed,

@@ -47,6 +50,8 @@ class Request:
             "use_upscale": self.use_upscale,
             "use_stable_diffusion_model": self.use_stable_diffusion_model,
             "use_vae_model": self.use_vae_model,
+            "use_hypernetwork_model": self.use_hypernetwork_model,
+            "hypernetwork_strength": self.hypernetwork_strength,
             "output_format": self.output_format,
             "output_quality": self.output_quality,
         }

@@ -70,6 +75,8 @@ class Request:
 use_upscale: {self.use_upscale}
 use_stable_diffusion_model: {self.use_stable_diffusion_model}
 use_vae_model: {self.use_vae_model}
+use_hypernetwork_model: {self.use_hypernetwork_model}
+hypernetwork_strength: {self.hypernetwork_strength}
 show_only_filtered_image: {self.show_only_filtered_image}
 output_format: {self.output_format}
 output_quality: {self.output_quality}

ui/sd_internal/hypernetwork.py (new file, 198 lines)

@@ -0,0 +1,198 @@
+# this is basically a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py, mostly for feature parity
+# I, c0bra5, don't really understand how deep learning works. I just know how to port stuff.
+
+import inspect
+import torch
+import optimizedSD.splitAttention
+from . import runtime
+from einops import rearrange
+
+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
+loaded_hypernetwork = None
+
+class HypernetworkModule(torch.nn.Module):
+    multiplier = 0.5
+    activation_dict = {
+        "linear": torch.nn.Identity,
+        "relu": torch.nn.ReLU,
+        "leakyrelu": torch.nn.LeakyReLU,
+        "elu": torch.nn.ELU,
+        "swish": torch.nn.Hardswish,
+        "tanh": torch.nn.Tanh,
+        "sigmoid": torch.nn.Sigmoid,
+    }
+    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
+        super().__init__()
+
+        assert layer_structure is not None, "layer_structure must not be None"
+        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
+
+        linears = []
+        for i in range(len(layer_structure) - 1):
+
+            # Add a fully-connected layer
+            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
+            # Add an activation func except last layer
+            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
+                pass
+            elif activation_func in self.activation_dict:
+                linears.append(self.activation_dict[activation_func]())
+            else:
+                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
+            # Add layer normalization
+            if add_layer_norm:
+                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+
+            # Add dropout except last layer
+            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
+                linears.append(torch.nn.Dropout(p=0.3))
+
+        self.linear = torch.nn.Sequential(*linears)
+
+        self.fix_old_state_dict(state_dict)
+        self.load_state_dict(state_dict)
+
+        self.to(runtime.thread_data.device)
+
+    def fix_old_state_dict(self, state_dict):
+        changes = {
+            'linear1.bias': 'linear.0.bias',
+            'linear1.weight': 'linear.0.weight',
+            'linear2.bias': 'linear.1.bias',
+            'linear2.weight': 'linear.1.weight',
+        }
+
+        for fr, to in changes.items():
+            x = state_dict.get(fr, None)
+            if x is None:
+                continue
+
+            del state_dict[fr]
+            state_dict[to] = x
+
+    def forward(self, x: torch.Tensor):
+        return x + self.linear(x) * runtime.thread_data.hypernetwork_strength
+
+def apply_hypernetwork(hypernetwork, context, layer=None):
+    hypernetwork_layers = hypernetwork.get(context.shape[2], None)
+
+    if hypernetwork_layers is None:
+        return context, context
+
+    if layer is not None:
+        layer.hyper_k = hypernetwork_layers[0]
+        layer.hyper_v = hypernetwork_layers[1]
+
+    context_k = hypernetwork_layers[0](context)
+    context_v = hypernetwork_layers[1](context)
+    return context_k, context_v
+
+def get_kv(context, hypernetwork):
+    if hypernetwork is None:
+        return context, context
+    else:
+        return apply_hypernetwork(runtime.thread_data.hypernetwork, context)
+
+# This might need updating as the optimisedSD code changes
+# I think yall have a system for this (patch files in sd_internal) but idk how it works and no amount of searching gave me any clue
+# just in case for attribution https://github.com/easydiffusion/diffusion-kit/blob/e8ea0cadd543056059cd951e76d4744de76327d2/optimizedSD/splitAttention.py#L171
+def new_cross_attention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    # default context
+    context = context if context is not None else x() if inspect.isfunction(x) else x
+    # hypernetwork!
+    context_k, context_v = get_kv(context, runtime.thread_data.hypernetwork)
+    k = self.to_k(context_k)
+    v = self.to_v(context_v)
+    del context, x
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+
+    limit = k.shape[0]
+    att_step = self.att_step
+    q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
+    k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
+    v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
+
+    q_chunks.reverse()
+    k_chunks.reverse()
+    v_chunks.reverse()
+    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+    del k, q, v
+    for i in range (0, limit, att_step):
+
+        q_buffer = q_chunks.pop()
+        k_buffer = k_chunks.pop()
+        v_buffer = v_chunks.pop()
+        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+        del k_buffer, q_buffer
+        # attention, what we cannot get enough of, by chunks
+
+        sim_buffer = sim_buffer.softmax(dim=-1)
+
+        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+        del v_buffer
+        sim[i:i+att_step,:,:] = sim_buffer
+
+        del sim_buffer
+    sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(sim)
+
+
+def load_hypernetwork(path: str):
+
+    state_dict = torch.load(path, map_location='cpu')
+
+    layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+    activation_func = state_dict.get('activation_func', None)
+    weight_init = state_dict.get('weight_initialization', 'Normal')
+    add_layer_norm = state_dict.get('is_layer_norm', False)
+    use_dropout = state_dict.get('use_dropout', False)
+    activate_output = state_dict.get('activate_output', True)
+    last_layer_dropout = state_dict.get('last_layer_dropout', False)
+    # this is a bit verbose so leaving it commented out for the poor soul who ever has to debug this
+    # print(f"layer_structure: {layer_structure}")
+    # print(f"activation_func: {activation_func}")
+    # print(f"weight_init: {weight_init}")
+    # print(f"add_layer_norm: {add_layer_norm}")
+    # print(f"use_dropout: {use_dropout}")
+    # print(f"activate_output: {activate_output}")
+    # print(f"last_layer_dropout: {last_layer_dropout}")
+
+    layers = {}
+    for size, sd in state_dict.items():
+        if type(size) == int:
+            layers[size] = (
+                HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm,
+                                   use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
+                HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm,
+                                   use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
+            )
+    print(f"hypernetwork loaded")
+    return layers
+
+
+
+# overriding of original function
+old_cross_attention_forward = optimizedSD.splitAttention.CrossAttention.forward
+# hijacks the cross attention forward function to add hyper network support
+def hijack_cross_attention():
+    print("hypernetwork functionality added to cross attention")
+    optimizedSD.splitAttention.CrossAttention.forward = new_cross_attention_forward
+# there was a cop on board
+def unhijack_cross_attention_forward():
+    print("hypernetwork functionality removed from cross attention")
+    optimizedSD.splitAttention.CrossAttention.forward = old_cross_attention_forward
+
+hijack_cross_attention()

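To illustrate what the new module does (this sketch is not part of the commit): load_hypernetwork() returns pairs of small residual MLPs keyed by the context embedding width (context.shape[2]), and apply_hypernetwork() feeds the cross-attention key input through one and the value input through the other, with the residual scaled by the hypernetwork strength. A self-contained approximation with hypothetical dimensions:

# Illustrative sketch only: mimics HypernetworkModule.forward, i.e.
# context -> context + mlp(context) * strength, applied separately to the k and v inputs.
import torch

def residual_mlp(dim: int) -> torch.nn.Sequential:
    # comparable to the default layer_structure [1, 2, 1]: dim -> 2*dim -> dim
    return torch.nn.Sequential(
        torch.nn.Linear(dim, dim * 2),
        torch.nn.Linear(dim * 2, dim),
    )

dim = 768        # hypothetical context width; real hypernetwork files key their layers by this size
strength = 0.5   # plays the role of thread_data.hypernetwork_strength
hyper_k, hyper_v = residual_mlp(dim), residual_mlp(dim)

context = torch.randn(1, 77, dim)                    # (batch, tokens, dim), like a prompt embedding
context_k = context + hyper_k(context) * strength    # what self.to_k() receives instead of context
context_v = context + hyper_v(context) * strength    # what self.to_v() receives instead of context
print(context_k.shape, context_v.shape)              # both torch.Size([1, 77, 768])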
@@ -28,6 +28,9 @@ from gfpgan import GFPGANer
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from realesrgan import RealESRGANer
 
+from . import hypernetwork
+from server import HYPERNETWORK_MODEL_EXTENSIONS # , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS
+
 from threading import Lock
 from safetensors.torch import load_file
 

@@ -57,12 +60,15 @@ def thread_init(device):
 
     thread_data.ckpt_file = None
     thread_data.vae_file = None
+    thread_data.hypernetwork_file = None
     thread_data.gfpgan_file = None
    thread_data.real_esrgan_file = None
 
     thread_data.model = None
     thread_data.modelCS = None
     thread_data.modelFS = None
+    thread_data.hypernetwork = None
+    thread_data.hypernetwork_strength = 1
     thread_data.model_gfpgan = None
     thread_data.model_real_esrgan = None
 

@@ -72,6 +78,8 @@ def thread_init(device):
     thread_data.device_name = None
     thread_data.unet_bs = 1
     thread_data.precision = 'autocast'
+    thread_data.sampler_plms = None
+    thread_data.sampler_ddim = None
 
     thread_data.turbo = False
     thread_data.force_full_precision = False

@@ -433,6 +441,49 @@ def reload_model():
     unload_filters()
     load_model_ckpt()
 
+def is_hypernetwork_reload_necessary(req: Request):
+    needs_model_reload = False
+    if thread_data.hypernetwork_file != req.use_hypernetwork_model:
+        thread_data.hypernetwork_file = req.use_hypernetwork_model
+        needs_model_reload = True
+
+    return needs_model_reload
+
+def load_hypernetwork():
+    if thread_data.hypernetwork_file is not None:
+        try:
+            loaded = False
+            for model_extension in HYPERNETWORK_MODEL_EXTENSIONS:
+                if os.path.exists(thread_data.hypernetwork_file + model_extension):
+                    print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}")
+                    thread_data.hypernetwork = hypernetwork.load_hypernetwork(thread_data.hypernetwork_file + model_extension)
+                    loaded = True
+                    break
+
+            if not loaded:
+                print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}')
+                thread_data.hypernetwork_file = None
+        except:
+            print(traceback.format_exc())
+            print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}')
+            thread_data.hypernetwork_file = None
+
+def unload_hypernetwork():
+    if thread_data.hypernetwork is not None:
+        print('Unloading hypernetwork...')
+        if thread_data.device != 'cpu':
+            for i in thread_data.hypernetwork:
+                thread_data.hypernetwork[i][0].to('cpu')
+                thread_data.hypernetwork[i][1].to('cpu')
+        del thread_data.hypernetwork
+        thread_data.hypernetwork = None
+
+    gc()
+
+def reload_hypernetwork():
+    unload_hypernetwork()
+    load_hypernetwork()
+
 def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
         return do_mk_img(req, data_queue, task_temp_images, step_callback)

@@ -509,6 +560,7 @@ def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, ste
     res = Response()
     res.request = req
     res.images = []
+    thread_data.hypernetwork_strength = req.hypernetwork_strength
 
     thread_data.temp_images.clear()
 

@@ -751,6 +803,8 @@ Sampler: {req.sampler}
 Negative Prompt: {req.negative_prompt}
 Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
 VAE model: {req.use_vae_model}
+Hypernetwork Model: {req.use_hypernetwork_model}
+Hypernetwork Strength: {req.hypernetwork_strength}
 '''
     try:
         with open(meta_out_path, 'w', encoding='utf-8') as f:

@@ -77,6 +77,8 @@ class ImageRequest(BaseModel):
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
     use_stable_diffusion_model: str = "sd-v1-4"
     use_vae_model: str = None
+    use_hypernetwork_model: str = None
+    hypernetwork_strength: float = None
     show_only_filtered_image: bool = False
     output_format: str = "jpeg" # or "png"
     output_quality: int = 75

@@ -177,28 +179,35 @@ current_state = ServerStates.Init
 current_state_error:Exception = None
 current_model_path = None
 current_vae_path = None
+current_hypernetwork_path = None
 tasks_queue = []
 task_cache = TaskCache()
 default_model_to_load = None
 default_vae_to_load = None
+default_hypernetwork_to_load = None
 weak_thread_data = weakref.WeakKeyDictionary()
 
-def preload_model(ckpt_file_path=None, vae_file_path=None):
-    global current_state, current_state_error, current_model_path, current_vae_path
+def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
+    global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
     if ckpt_file_path == None:
         ckpt_file_path = default_model_to_load
     if vae_file_path == None:
         vae_file_path = default_vae_to_load
+    if hypernetwork_file_path == None:
+        hypernetwork_file_path = default_hypernetwork_to_load
     if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
         return
     current_state = ServerStates.LoadingModel
     try:
         from . import runtime
+        runtime.thread_data.hypernetwork_file = hypernetwork_file_path
         runtime.thread_data.ckpt_file = ckpt_file_path
         runtime.thread_data.vae_file = vae_file_path
         runtime.load_model_ckpt()
+        runtime.load_hypernetwork()
         current_model_path = ckpt_file_path
         current_vae_path = vae_file_path
+        current_hypernetwork_path = hypernetwork_file_path
         current_state_error = None
         current_state = ServerStates.Online
     except Exception as e:

@@ -240,7 +249,7 @@ def thread_get_next_task():
         manager_lock.release()
 
 def thread_render(device):
-    global current_state, current_state_error, current_model_path, current_vae_path
+    global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
     from . import runtime
     try:
         runtime.thread_init(device)

@@ -285,6 +294,10 @@ def thread_render(device):
         print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
         if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
         try:
+            if runtime.is_hypernetwork_reload_necessary(task.request):
+                runtime.reload_hypernetwork()
+                current_hypernetwork_path = task.request.use_hypernetwork_model
+
             if runtime.is_model_reload_necessary(task.request):
                 current_state = ServerStates.LoadingModel
                 runtime.reload_model()

|
|||||||
r.use_face_correction = req.use_face_correction
|
r.use_face_correction = req.use_face_correction
|
||||||
r.use_stable_diffusion_model = req.use_stable_diffusion_model
|
r.use_stable_diffusion_model = req.use_stable_diffusion_model
|
||||||
r.use_vae_model = req.use_vae_model
|
r.use_vae_model = req.use_vae_model
|
||||||
|
r.use_hypernetwork_model = req.use_hypernetwork_model
|
||||||
|
r.hypernetwork_strength = req.hypernetwork_strength
|
||||||
r.show_only_filtered_image = req.show_only_filtered_image
|
r.show_only_filtered_image = req.show_only_filtered_image
|
||||||
r.output_format = req.output_format
|
r.output_format = req.output_format
|
||||||
r.output_quality = req.output_quality
|
r.output_quality = req.output_quality
|
||||||
|
20
ui/server.py
20
ui/server.py
@ -26,6 +26,7 @@ UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user
|
|||||||
|
|
||||||
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
|
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
|
||||||
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
|
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
|
||||||
|
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
|
||||||
|
|
||||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||||
@@ -193,6 +194,12 @@ def resolve_vae_to_use(model_name:str=None):
     except:
         return None
 
+def resolve_hypernetwork_to_use(model_name:str=None):
+    try:
+        return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
+    except:
+        return None
+
 class SetAppConfigRequest(BaseModel):
     update_branch: str = None
     render_devices: Union[List[str], List[int], str, int] = None

@@ -253,10 +260,12 @@ def getModels():
         'active': {
             'stable-diffusion': 'sd-v1-4',
             'vae': '',
+            'hypernetwork': '',
         },
         'options': {
             'stable-diffusion': ['sd-v1-4'],
             'vae': [],
+            'hypernetwork': [],
         },
     }
 

@@ -288,7 +297,7 @@ def getModels():
     # custom models
     listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
     listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
+    listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
     # legacy
     custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
     if os.path.exists(custom_weight_path):

@@ -363,16 +372,19 @@ def ping(session_id:str=None):
     response['devices'] = task_manager.get_devices()
     return JSONResponse(response, headers=NOCACHE_HEADERS)
 
-def save_model_to_config(ckpt_model_name, vae_model_name):
+def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
     config = getConfig()
     if 'model' not in config:
         config['model'] = {}
 
     config['model']['stable-diffusion'] = ckpt_model_name
     config['model']['vae'] = vae_model_name
+    config['model']['hypernetwork'] = hypernetwork_model_name
 
     if vae_model_name is None or vae_model_name == "":
         del config['model']['vae']
+    if hypernetwork_model_name is None or hypernetwork_model_name == "":
+        del config['model']['hypernetwork']
 
     setConfig(config)
 

@@ -388,9 +400,10 @@ def update_render_devices_in_config(config, render_devices):
 @app.post('/render')
 def render(req : task_manager.ImageRequest):
     try:
-        save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model)
+        save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
         req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
         req.use_vae_model = resolve_vae_to_use(req.use_vae_model)
+        req.use_hypernetwork_model = resolve_hypernetwork_to_use(req.use_hypernetwork_model)
         new_task = task_manager.render(req)
         response = {
             'status': str(task_manager.current_state),

@@ -469,6 +482,7 @@ getModels()
 # Start the task_manager
 task_manager.default_model_to_load = resolve_ckpt_to_use()
 task_manager.default_vae_to_load = resolve_vae_to_use()
+task_manager.default_hypernetwork_to_load = resolve_hypernetwork_to_use()
 
 def update_render_threads():
     config = getConfig()