Merge branch 'beta' into confirm

This commit is contained in:
cmdr2 2022-11-30 13:47:08 +05:30 committed by GitHub
commit 0b96fa112d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 729 additions and 376 deletions

View File

@ -22,7 +22,9 @@
- Ask for a confimation before clearing the results pane or stopping a render task. The dialog can be skipped by holding down the shift key while clicking on the button.
### Detailed changelog
* 2.4.14 - 22 Nov 2022 - shiftOrConfirm for red buttons
* 2.4.17 - 30 Nov 2022 - Confirm before stopping or clearing all the tasks
* 2.4.16 - 29 Nov 2022 - Bug fixes for SD 2.0 - remove the need for patching, default to SD 1.4 model if trying to load an SD2 model in SD1.4.
* 2.4.15 - 25 Nov 2022 - Experimental support for SD 2.0. Uses lots of memory, not optimized, probably GPU-only.
* 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion
* 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac
* 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing.

View File

@ -29,6 +29,18 @@ call conda activate .\stable-diffusion\env
call where python
call python --version
@rem set the PYTHONPATH
cd stable-diffusion
set SD_DIR=%cd%
cd env\lib\site-packages
set PYTHONPATH=%SD_DIR%;%cd%
cd ..\..\..
echo PYTHONPATH=%PYTHONPATH%
cd ..
@rem done
echo.
cmd /k

View File

@ -35,6 +35,15 @@ if [ "$0" == "bash" ]; then
which python
python --version
# set the PYTHONPATH
cd stable-diffusion
SD_PATH=`pwd`
export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages"
echo "PYTHONPATH=$PYTHONPATH"
cd ..
# done
echo ""
else
file_name=$(basename "${BASH_SOURCE[0]}")

View File

@ -27,6 +27,8 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"
if NOT DEFINED test_sd2 set test_sd2=N
@>nul findstr /m "sd_git_cloned" scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
@echo "Stable Diffusion's git repository was already installed. Updating.."
@ -37,9 +39,13 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call git reset --hard
@call git pull
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
if "%test_sd2%" == "N" (
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
)
if "%test_sd2%" == "Y" (
@call git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
)
@cd ..
) else (
@ -56,8 +62,6 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@cd stable-diffusion
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
@cd ..
)
@ -346,7 +350,9 @@ echo. > "..\models\vae\Put your VAE files here.txt"
)
)
if "%test_sd2%" == "Y" (
@call pip install open_clip_torch==2.0.2
)
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" NEQ "0" (

View File

@ -21,6 +21,10 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d
# Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
# Note to self: Please rewrite this in Python. For the sake of your own sanity.
if [ "$test_sd2" == "" ]; then
export test_sd2="N"
fi
if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
echo "Stable Diffusion's git repository was already installed. Updating.."
@ -30,9 +34,12 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
git reset --hard
git pull
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
if [ "$test_sd2" == "N" ]; then
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
elif [ "$test_sd2" == "Y" ]; then
git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
fi
cd ..
else
@ -47,8 +54,6 @@ else
cd stable-diffusion
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
cd ..
fi
@ -291,6 +296,9 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
fi
fi
if [ "$test_sd2" == "Y" ]; then
pip install open_clip_torch==2.0.2
fi
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
echo sd_weights_downloaded >> ../scripts/install_status.txt

View File

@ -24,7 +24,7 @@
<div id="logo">
<h1>
Stable Diffusion UI
<small>v2.4.14 <span id="updateBranchLabel"></span></small>
<small>v2.4.16 <span id="updateBranchLabel"></span></small>
</h1>
</div>
<div id="server-status">

View File

@ -210,7 +210,7 @@ code {
}
.collapsible-content {
display: block;
padding-left: 15px;
padding-left: 10px;
}
.collapsible-content h5 {
padding: 5pt 0pt;
@ -658,11 +658,15 @@ input::file-selector-button {
opacity: 1;
}
/* MOBILE SUPPORT */
@media screen and (max-width: 700px) {
/* Small screens */
@media screen and (max-width: 1265px) {
#top-nav {
flex-direction: column;
}
}
/* MOBILE SUPPORT */
@media screen and (max-width: 700px) {
body {
margin: 0px;
}
@ -712,7 +716,7 @@ input::file-selector-button {
padding-right: 0px;
}
#server-status {
display: none;
top: 75%;
}
.popup > div {
padding-left: 5px !important;
@ -730,6 +734,15 @@ input::file-selector-button {
}
}
@media screen and (max-width: 500px) {
#server-status #server-status-msg {
display: none;
}
#server-status:hover #server-status-msg {
display: inline;
}
}
@media (min-width: 700px) {
/* #editor {
max-width: 480px;

View File

@ -90,9 +90,7 @@ function createModifierGroup(modifierGroup, initiallyExpanded) {
if (activeTags.map(x => x.name).includes(modifierName)) {
// remove modifier from active array
activeTags = activeTags.filter(x => x.name != modifierName)
modifierCard.classList.remove(activeCardClass)
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+'
toggleCardState(modifierCard, false)
} else {
// add modifier to active array
activeTags.push({
@ -101,10 +99,7 @@ function createModifierGroup(modifierGroup, initiallyExpanded) {
'originElement': modifierCard,
'previews': modifierPreviews
})
modifierCard.classList.add(activeCardClass)
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-'
toggleCardState(modifierCard, true)
}
refreshTagsList()
@ -222,8 +217,7 @@ function refreshTagsList() {
let idx = activeTags.indexOf(tag)
if (idx !== -1 && activeTags[idx].originElement !== undefined) {
activeTags[idx].originElement.classList.remove(activeCardClass)
activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+'
toggleCardState(activeTags[idx].originElement, false)
activeTags.splice(idx, 1)
refreshTagsList()
@ -236,6 +230,16 @@ function refreshTagsList() {
editorModifierTagsList.appendChild(brk)
}
function toggleCardState(card, makeActive) {
if (makeActive) {
card.classList.add(activeCardClass)
card.querySelector('.modifier-card-image-overlay').innerText = '-'
} else {
card.classList.remove(activeCardClass)
card.querySelector('.modifier-card-image-overlay').innerText = '+'
}
}
function changePreviewImages(val) {
const previewImages = document.querySelectorAll('.modifier-card-image-container img')

View File

@ -960,10 +960,10 @@ function getPrompts() {
prompts = prompts.filter(prompt => prompt !== '')
if (activeTags.length > 0) {
const promptTags = activeTags.map(x => x.name).join(", ")
prompts = prompts.map((prompt) => `${prompt}, ${promptTags}`)
const promptTags = activeTags.map(x => x.name).join(", ")
prompts = prompts.map((prompt) => `${prompt}, ${promptTags}`)
}
let promptsToMake = applySetOperator(prompts)
promptsToMake = applyPermuteOperator(promptsToMake)

View File

@ -5,9 +5,9 @@
*/
var ParameterType = {
checkbox: "checkbox",
select: "select",
select_multiple: "select_multiple",
custom: "custom",
select: "select",
select_multiple: "select_multiple",
custom: "custom",
};
/**
@ -23,174 +23,182 @@
/** @type {Array.<Parameter>} */
var PARAMETERS = [
{
id: "theme",
type: ParameterType.select,
label: "Theme",
default: "theme-default",
note: "customize the look and feel of the ui",
options: [ // Note: options expanded dynamically
{
value: "theme-default",
label: "Default"
}
],
icon: "fa-palette"
},
{
id: "save_to_disk",
type: ParameterType.checkbox,
label: "Auto-Save Images",
note: "automatically saves images to the specified location",
icon: "fa-download",
default: false,
},
{
id: "diskPath",
type: ParameterType.custom,
label: "Save Location",
render: (parameter) => {
return `<input id="${parameter.id}" name="${parameter.id}" size="30" disabled>`
}
},
{
id: "sound_toggle",
type: ParameterType.checkbox,
label: "Enable Sound",
note: "plays a sound on task completion",
icon: "fa-volume-low",
default: true,
},
{
id: "ui_open_browser_on_start",
type: ParameterType.checkbox,
label: "Open browser on startup",
note: "starts the default browser on startup",
icon: "fa-window-restore",
default: true,
},
{
id: "turbo",
type: ParameterType.checkbox,
label: "Turbo Mode",
note: "generates images faster, but uses an additional 1 GB of GPU memory",
icon: "fa-forward",
default: true,
},
{
id: "use_cpu",
type: ParameterType.checkbox,
label: "Use CPU (not GPU)",
note: "warning: this will be *very* slow",
icon: "fa-microchip",
default: false,
},
{
id: "auto_pick_gpus",
type: ParameterType.checkbox,
label: "Automatically pick the GPUs (experimental)",
default: false,
},
{
id: "use_gpus",
type: ParameterType.select_multiple,
label: "GPUs to use (experimental)",
note: "to process in parallel",
default: false,
},
{
id: "use_full_precision",
type: ParameterType.checkbox,
label: "Use Full Precision",
note: "for GPU-only. warning: this will consume more VRAM",
icon: "fa-crosshairs",
default: false,
},
{
id: "auto_save_settings",
type: ParameterType.checkbox,
label: "Auto-Save Settings",
note: "restores settings on browser load",
icon: "fa-gear",
default: true,
},
{
id: "confirm_dangerous_actions",
type: ParameterType.checkbox,
label: "Confirm dangerous actions",
note: "Actions that might lead to data loss must either be clicked with the shift key pressed, or confirmed in an 'are you sure?' dialog",
icon: "fa-check-double",
default: true,
},
{
id: "listen_to_network",
type: ParameterType.checkbox,
label: "Make Stable Diffusion available on your network",
note: "Other devices on your network can access this web page",
icon: "fa-network-wired",
default: true,
},
{
id: "listen_port",
type: ParameterType.custom,
label: "Network port",
note: "Port that this server listens to. The '9000' part in 'http://localhost:9000'",
icon: "fa-anchor",
render: (parameter) => {
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
}
},
{
id: "use_beta_channel",
type: ParameterType.checkbox,
label: "Beta channel",
note: "Get the latest features immediately (but could be less stable). Please restart the program after changing this.",
icon: "fa-fire",
default: false,
},
{
id: "theme",
type: ParameterType.select,
label: "Theme",
default: "theme-default",
note: "customize the look and feel of the ui",
options: [ // Note: options expanded dynamically
{
value: "theme-default",
label: "Default"
}
],
icon: "fa-palette"
},
{
id: "save_to_disk",
type: ParameterType.checkbox,
label: "Auto-Save Images",
note: "automatically saves images to the specified location",
icon: "fa-download",
default: false,
},
{
id: "diskPath",
type: ParameterType.custom,
label: "Save Location",
render: (parameter) => {
return `<input id="${parameter.id}" name="${parameter.id}" size="30" disabled>`
}
},
{
id: "sound_toggle",
type: ParameterType.checkbox,
label: "Enable Sound",
note: "plays a sound on task completion",
icon: "fa-volume-low",
default: true,
},
{
id: "ui_open_browser_on_start",
type: ParameterType.checkbox,
label: "Open browser on startup",
note: "starts the default browser on startup",
icon: "fa-window-restore",
default: true,
},
{
id: "turbo",
type: ParameterType.checkbox,
label: "Turbo Mode",
note: "generates images faster, but uses an additional 1 GB of GPU memory",
icon: "fa-forward",
default: true,
},
{
id: "use_cpu",
type: ParameterType.checkbox,
label: "Use CPU (not GPU)",
note: "warning: this will be *very* slow",
icon: "fa-microchip",
default: false,
},
{
id: "auto_pick_gpus",
type: ParameterType.checkbox,
label: "Automatically pick the GPUs (experimental)",
default: false,
},
{
id: "use_gpus",
type: ParameterType.select_multiple,
label: "GPUs to use (experimental)",
note: "to process in parallel",
default: false,
},
{
id: "use_full_precision",
type: ParameterType.checkbox,
label: "Use Full Precision",
note: "for GPU-only. warning: this will consume more VRAM",
icon: "fa-crosshairs",
default: false,
},
{
id: "auto_save_settings",
type: ParameterType.checkbox,
label: "Auto-Save Settings",
note: "restores settings on browser load",
icon: "fa-gear",
default: true,
},
{
id: "confirm_dangerous_actions",
type: ParameterType.checkbox,
label: "Confirm dangerous actions",
note: "Actions that might lead to data loss must either be clicked with the shift key pressed, or confirmed in an 'Are you sure?' dialog",
icon: "fa-check-double",
default: true,
},
{
id: "listen_to_network",
type: ParameterType.checkbox,
label: "Make Stable Diffusion available on your network",
note: "Other devices on your network can access this web page",
icon: "fa-network-wired",
default: true,
},
{
id: "listen_port",
type: ParameterType.custom,
label: "Network port",
note: "Port that this server listens to. The '9000' part in 'http://localhost:9000'",
icon: "fa-anchor",
render: (parameter) => {
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
}
},
{
id: "test_sd2",
type: ParameterType.checkbox,
label: "Test SD 2.0",
note: "Experimental! High memory usage! GPU-only! Not the final version! Please restart the program after changing this.",
icon: "fa-fire",
default: false,
},
{
id: "use_beta_channel",
type: ParameterType.checkbox,
label: "Beta channel",
note: "Get the latest features immediately (but could be less stable). Please restart the program after changing this.",
icon: "fa-fire",
default: false,
},
];
function getParameterSettingsEntry(id) {
let parameter = PARAMETERS.filter(p => p.id === id)
if (parameter.length === 0) {
return
}
return parameter[0].settingsEntry
let parameter = PARAMETERS.filter(p => p.id === id)
if (parameter.length === 0) {
return
}
return parameter[0].settingsEntry
}
function getParameterElement(parameter) {
switch (parameter.type) {
case ParameterType.checkbox:
var is_checked = parameter.default ? " checked" : "";
return `<input id="${parameter.id}" name="${parameter.id}"${is_checked} type="checkbox">`
case ParameterType.select:
case ParameterType.select_multiple:
var options = (parameter.options || []).map(option => `<option value="${option.value}">${option.label}</option>`).join("")
var multiple = (parameter.type == ParameterType.select_multiple ? 'multiple' : '')
return `<select id="${parameter.id}" name="${parameter.id}" ${multiple}>${options}</select>`
case ParameterType.custom:
return parameter.render(parameter)
default:
console.error(`Invalid type for parameter ${parameter.id}`);
return "ERROR: Invalid Type"
}
switch (parameter.type) {
case ParameterType.checkbox:
var is_checked = parameter.default ? " checked" : "";
return `<input id="${parameter.id}" name="${parameter.id}"${is_checked} type="checkbox">`
case ParameterType.select:
case ParameterType.select_multiple:
var options = (parameter.options || []).map(option => `<option value="${option.value}">${option.label}</option>`).join("")
var multiple = (parameter.type == ParameterType.select_multiple ? 'multiple' : '')
return `<select id="${parameter.id}" name="${parameter.id}" ${multiple}>${options}</select>`
case ParameterType.custom:
return parameter.render(parameter)
default:
console.error(`Invalid type for parameter ${parameter.id}`);
return "ERROR: Invalid Type"
}
}
let parametersTable = document.querySelector("#system-settings .parameters-table")
/* fill in the system settings popup table */
function initParameters() {
PARAMETERS.forEach(parameter => {
var element = getParameterElement(parameter)
var note = parameter.note ? `<small>${parameter.note}</small>` : "";
var icon = parameter.icon ? `<i class="fa ${parameter.icon}"></i>` : "";
var newrow = document.createElement('div')
newrow.innerHTML = `
<div>${icon}</div>
<div><label for="${parameter.id}">${parameter.label}</label>${note}</div>
<div>${element}</div>`
parametersTable.appendChild(newrow)
parameter.settingsEntry = newrow
})
PARAMETERS.forEach(parameter => {
var element = getParameterElement(parameter)
var note = parameter.note ? `<small>${parameter.note}</small>` : "";
var icon = parameter.icon ? `<i class="fa ${parameter.icon}"></i>` : "";
var newrow = document.createElement('div')
newrow.innerHTML = `
<div>${icon}</div>
<div><label for="${parameter.id}">${parameter.label}</label>${note}</div>
<div>${element}</div>`
parametersTable.appendChild(newrow)
parameter.settingsEntry = newrow
})
}
initParameters()
@ -204,6 +212,7 @@ let saveToDiskField = document.querySelector('#save_to_disk')
let diskPathField = document.querySelector('#diskPath')
let listenToNetworkField = document.querySelector("#listen_to_network")
let listenPortField = document.querySelector("#listen_port")
let testSD2Field = document.querySelector("#test_sd2")
let useBetaChannelField = document.querySelector("#use_beta_channel")
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions")
@ -239,12 +248,18 @@ async function getAppConfig() {
if (config.ui && config.ui.open_browser_on_start === false) {
uiOpenBrowserOnStartField.checked = false
}
if (config.net && config.net.listen_to_network === false) {
listenToNetworkField.checked = false
}
if (config.net && config.net.listen_port !== undefined) {
listenPortField.value = config.net.listen_port
}
if ('test_sd2' in config) {
testSD2Field.checked = config['test_sd2']
}
let testSD2SettingEntry = getParameterSettingsEntry('test_sd2')
testSD2SettingEntry.style.display = (config.update_branch === 'beta' ? '' : 'none')
if (config.net && config.net.listen_to_network === false) {
listenToNetworkField.checked = false
}
if (config.net && config.net.listen_port !== undefined) {
listenPortField.value = config.net.listen_port
}
console.log('get config status response', config)
} catch (e) {
@ -272,7 +287,6 @@ function getCurrentRenderDeviceSelection() {
useCPUField.addEventListener('click', function() {
let gpuSettingEntry = getParameterSettingsEntry('use_gpus')
let autoPickGPUSettingEntry = getParameterSettingsEntry('auto_pick_gpus')
console.log("hello", this.checked);
if (this.checked) {
gpuSettingEntry.style.display = 'none'
autoPickGPUSettingEntry.style.display = 'none'
@ -369,22 +383,23 @@ async function getDevices() {
}
saveSettingsBtn.addEventListener('click', function() {
let updateBranch = (useBetaChannelField.checked ? 'beta' : 'main')
let updateBranch = (useBetaChannelField.checked ? 'beta' : 'main')
if (listenPortField.value == '') {
alert('The network port field must not be empty.')
} else if (listenPortField.value<1 || listenPortField.value>65535) {
alert('The network port must be a number from 1 to 65535')
} else {
changeAppConfig({
'render_devices': getCurrentRenderDeviceSelection(),
'update_branch': updateBranch,
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
'listen_to_network': listenToNetworkField.checked,
'listen_port': listenPortField.value
})
}
if (listenPortField.value == '') {
alert('The network port field must not be empty.')
} else if (listenPortField.value<1 || listenPortField.value>65535) {
alert('The network port must be a number from 1 to 65535')
} else {
changeAppConfig({
'render_devices': getCurrentRenderDeviceSelection(),
'update_branch': updateBranch,
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
'listen_to_network': listenToNetworkField.checked,
'listen_port': listenPortField.value,
'test_sd2': testSD2Field.checked
})
}
saveSettingsBtn.classList.add('active')
asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active'))
saveSettingsBtn.classList.add('active')
asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active'))
})

View File

@ -1,17 +1,17 @@
// https://gomakethings.com/finding-the-next-and-previous-sibling-elements-that-match-a-selector-with-vanilla-js/
function getNextSibling(elem, selector) {
// Get the next sibling element
var sibling = elem.nextElementSibling
// Get the next sibling element
var sibling = elem.nextElementSibling
// If there's no selector, return the first sibling
if (!selector) return sibling
// If there's no selector, return the first sibling
if (!selector) return sibling
// If the sibling matches our selector, use it
// If not, jump to the next sibling and continue the loop
while (sibling) {
if (sibling.matches(selector)) return sibling
sibling = sibling.nextElementSibling
}
// If the sibling matches our selector, use it
// If not, jump to the next sibling and continue the loop
while (sibling) {
if (sibling.matches(selector)) return sibling
sibling = sibling.nextElementSibling
}
}

View File

@ -0,0 +1,84 @@
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 27ead0e..6215939 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -100,7 +100,7 @@ class DDIMSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
- samples, intermediates = self.ddim_sampling(conditioning, size,
+ samples = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -117,7 +117,8 @@ class DDIMSampler(object):
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
- return samples, intermediates
+ # return samples, intermediates
+ yield from samples
@torch.no_grad()
def ddim_sampling(self, cond, shape,
@@ -168,14 +169,15 @@ class DDIMSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 7002a36..0951f39 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -96,7 +96,7 @@ class PLMSSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
- samples, intermediates = self.plms_sampling(conditioning, size,
+ samples = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -112,7 +112,8 @@ class PLMSSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
- return samples, intermediates
+ #return samples, intermediates
+ yield from samples
@torch.no_grad()
def plms_sampling(self, cond, shape,
@@ -165,14 +166,15 @@ class PLMSSampler(object):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,

View File

@ -7,6 +7,7 @@ Notes:
import json
import os, re
import traceback
import queue
import torch
import numpy as np
from gc import collect as gc_collect
@ -21,13 +22,14 @@ from torch import autocast
from contextlib import nullcontext
from einops import rearrange, repeat
from ldm.util import instantiate_from_config
from optimizedSD.optimUtils import split_weighted_subprompts
from transformers import logging
from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from threading import Lock
import uuid
logging.set_verbosity_error()
@ -35,7 +37,7 @@ logging.set_verbosity_error()
# consts
config_yaml = "optimizedSD/v1-inference.yaml"
filename_regex = re.compile('[^a-zA-Z0-9]')
force_gfpgan_to_cuda0 = True # workaround: gfpgan currently works only on cuda:0
gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time.
# api stuff
from sd_internal import device_manager
@ -76,8 +78,24 @@ def thread_init(device):
thread_data.force_full_precision = False
thread_data.reduced_memory = True
thread_data.test_sd2 = isSD2()
device_manager.device_init(thread_data, device)
# temp hack, will remove soon
def isSD2():
try:
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return False
with open(config_json_path, 'r', encoding='utf-8') as f:
config = json.load(f)
return config.get('test_sd2', False)
except Exception as e:
return False
def load_model_ckpt():
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
@ -92,6 +110,13 @@ def load_model_ckpt():
thread_data.precision = 'full'
print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
if thread_data.test_sd2:
load_model_ckpt_sd2()
else:
load_model_ckpt_sd1()
def load_model_ckpt_sd1():
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
li, lo = [], []
for key, value in sd.items():
@ -185,6 +210,38 @@ def load_model_ckpt():
modelFS.device: {thread_data.modelFS.device}
using precision: {thread_data.precision}''')
def load_model_ckpt_sd2():
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in thread_data.ckpt_file else "configs/stable-diffusion/v1-inference.yaml"
config = OmegaConf.load(config_file)
verbose = False
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
thread_data.model = instantiate_from_config(config.model)
m, u = thread_data.model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
thread_data.model.to(thread_data.device)
thread_data.model.eval()
del sd
if thread_data.device != "cpu" and thread_data.precision == "autocast":
thread_data.model.half()
thread_data.model_is_half = True
thread_data.model_fs_is_half = True
else:
thread_data.model_is_half = False
thread_data.model_fs_is_half = False
print(f'''loaded model
model file: {thread_data.ckpt_file}.ckpt
using precision: {thread_data.precision}''')
def unload_filters():
if thread_data.model_gfpgan is not None:
if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
@ -204,10 +261,11 @@ def unload_models():
if thread_data.model is not None:
print('Unloading models...')
if thread_data.device != 'cpu':
thread_data.modelFS.to('cpu')
thread_data.modelCS.to('cpu')
thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu")
if not thread_data.test_sd2:
thread_data.modelFS.to('cpu')
thread_data.modelCS.to('cpu')
thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu")
del thread_data.model
del thread_data.modelCS
@ -253,12 +311,6 @@ def move_to_cpu(model):
def load_model_gfpgan():
if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
# hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
from facexlib.detection import retinaface
retinaface.device = torch.device(thread_data.device)
print('forced retinaface.device to', thread_data.device)
model_path = thread_data.gfpgan_file + ".pth"
thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
@ -314,15 +366,23 @@ def apply_filters(filter_name, image_data, model_path=None):
image_data.to(thread_data.device)
if filter_name == 'gfpgan':
if model_path is not None and model_path != thread_data.gfpgan_file:
thread_data.gfpgan_file = model_path
load_model_gfpgan()
elif not thread_data.model_gfpgan:
load_model_gfpgan()
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
image_data = output[:,:,::-1]
# This lock is only ever used here. No need to use timeout for the request. Should never deadlock.
with gfpgan_temp_device_lock: # Wait for any other devices to complete before starting.
# hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
from facexlib.detection import retinaface
retinaface.device = torch.device(thread_data.device)
print('forced retinaface.device to', thread_data.device)
if model_path is not None and model_path != thread_data.gfpgan_file:
thread_data.gfpgan_file = model_path
load_model_gfpgan()
elif not thread_data.model_gfpgan:
load_model_gfpgan()
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
image_data = output[:,:,::-1]
if filter_name == 'real_esrgan':
if model_path is not None and model_path != thread_data.real_esrgan_file:
@ -337,45 +397,73 @@ def apply_filters(filter_name, image_data, model_path=None):
return image_data
def mk_img(req: Request):
def is_model_reload_necessary(req: Request):
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
return needs_model_reload
def reload_model():
unload_models()
unload_filters()
load_model_ckpt()
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
yield from do_mk_img(req)
return do_mk_img(req, data_queue, task_temp_images, step_callback)
except Exception as e:
print(traceback.format_exc())
if thread_data.device != 'cpu':
if thread_data.device != 'cpu' and not thread_data.test_sd2:
thread_data.modelFS.to('cpu')
thread_data.modelCS.to('cpu')
thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu")
gc() # Release from memory.
yield json.dumps({
data_queue.put(json.dumps({
"status": 'failed',
"detail": str(e)
})
}))
raise e
def update_temp_img(req, x_samples):
def update_temp_img(req, x_samples, task_temp_images: list):
partial_images = []
for i in range(req.num_outputs):
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
if thread_data.test_sd2:
x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
else:
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
buf = img_to_buffer(img, output_format='JPEG')
del img, x_sample, x_sample_ddim
# don't delete x_samples, it is used in the code that called this callback
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
task_temp_images[i] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
return partial_images
# Build and return the apropriate generator for do_mk_img
def get_image_progress_generator(req, extra_props=None):
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
if not req.stream_progress_updates:
def empty_callback(x_samples, i): return x_samples
return empty_callback
@ -394,15 +482,17 @@ def get_image_progress_generator(req, extra_props=None):
progress.update(extra_props)
if req.stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(req, x_samples)
progress['output'] = update_temp_img(req, x_samples, task_temp_images)
yield json.dumps(progress)
data_queue.put(json.dumps(progress))
step_callback()
if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return img_callback
def do_mk_img(req: Request):
def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
thread_data.stop_processing = False
res = Response()
@ -411,29 +501,7 @@ def do_mk_img(req: Request):
thread_data.temp_images.clear()
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
if needs_model_reload:
unload_models()
unload_filters()
load_model_ckpt()
if thread_data.turbo != req.turbo:
if thread_data.turbo != req.turbo and not thread_data.test_sd2:
thread_data.turbo = req.turbo
thread_data.model.turbo = req.turbo
@ -478,10 +546,14 @@ def do_mk_img(req: Request):
if thread_data.device != "cpu" and thread_data.precision == "autocast":
init_image = init_image.half()
thread_data.modelFS.to(thread_data.device)
if not thread_data.test_sd2:
thread_data.modelFS.to(thread_data.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
if thread_data.test_sd2:
init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space
else:
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None:
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
@ -493,7 +565,8 @@ def do_mk_img(req: Request):
# Send to CPU and wait until complete.
# wait_model_move_to(thread_data.modelFS, 'cpu')
move_to_cpu(thread_data.modelFS)
if not thread_data.test_sd2:
move_to_cpu(thread_data.modelFS)
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(req.prompt_strength * req.num_inference_steps)
@ -509,11 +582,14 @@ def do_mk_img(req: Request):
for prompts in tqdm(data, desc="data"):
with precision_scope("cuda"):
if thread_data.reduced_memory:
if thread_data.reduced_memory and not thread_data.test_sd2:
thread_data.modelCS.to(thread_data.device)
uc = None
if req.guidance_scale != 1.0:
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if thread_data.test_sd2:
uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
else:
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
@ -526,15 +602,21 @@ def do_mk_img(req: Request):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
if thread_data.test_sd2:
c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = thread_data.modelCS.get_learned_conditioning(prompts)
if thread_data.test_sd2:
c = thread_data.model.get_learned_conditioning(prompts)
else:
c = thread_data.modelCS.get_learned_conditioning(prompts)
if thread_data.reduced_memory:
if thread_data.reduced_memory and not thread_data.test_sd2:
thread_data.modelFS.to(thread_data.device)
n_steps = req.num_inference_steps if req.init_image is None else t_enc
img_callback = get_image_progress_generator(req, {"total_steps": n_steps})
img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})
# run the handler
try:
@ -542,14 +624,7 @@ def do_mk_img(req: Request):
if handler == _txt2img:
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)
if req.stream_progress_updates:
yield from x_samples
if hasattr(thread_data, 'partial_x_samples'):
if thread_data.partial_x_samples is not None:
x_samples = thread_data.partial_x_samples
del thread_data.partial_x_samples
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)
except UserInitiatedStop:
if not hasattr(thread_data, 'partial_x_samples'):
continue
@ -562,7 +637,10 @@ def do_mk_img(req: Request):
print("decoding images")
img_data = [None] * batch_size
for i in range(batch_size):
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
if thread_data.test_sd2:
x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
else:
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
@ -591,9 +669,11 @@ def do_mk_img(req: Request):
save_metadata(meta_out_path, req, prompts[0], opt_seed)
if return_orig_img:
img_str = img_to_base64_str(img, req.output_format)
img_buffer = img_to_buffer(img, req.output_format)
img_str = buffer_to_base64_str(img_buffer, req.output_format)
res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
res.images.append(res_image_orig)
task_temp_images[i] = img_buffer
if req.save_to_disk_path is not None:
res_image_orig.path_abs = img_out_path
@ -609,9 +689,11 @@ def do_mk_img(req: Request):
filters_applied.append(req.use_upscale)
if (len(filters_applied) > 0):
filtered_image = Image.fromarray(img_data[i])
filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
filtered_buffer = img_to_buffer(filtered_image, req.output_format)
filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
res.images.append(response_image)
task_temp_images[i] = filtered_buffer
if req.save_to_disk_path is not None:
filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
save_image(filtered_image, filtered_img_out_path)
@ -622,14 +704,18 @@ def do_mk_img(req: Request):
# if thread_data.reduced_memory:
# unload_filters()
move_to_cpu(thread_data.modelFS)
if not thread_data.test_sd2:
move_to_cpu(thread_data.modelFS)
del img_data
gc()
if thread_data.device != 'cpu':
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')
print('Task completed')
yield json.dumps(res.json())
res = res.json()
data_queue.put(json.dumps(res))
return res
def save_image(img, img_out_path):
try:
@ -664,51 +750,109 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
# Send to CPU and wait until complete.
# wait_model_move_to(thread_data.modelCS, 'cpu')
move_to_cpu(thread_data.modelCS)
if not thread_data.test_sd2:
move_to_cpu(thread_data.modelCS)
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'):
raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
yield from samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
# samples, _ = sampler.sample(S=opt.steps,
# conditioning=c,
# batch_size=opt.n_samples,
# shape=shape,
# verbose=False,
# unconditional_guidance_scale=opt.scale,
# unconditional_conditioning=uc,
# eta=opt.ddim_eta,
# x_T=start_code)
if thread_data.test_sd2:
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
if sampler_name == 'plms':
sampler = PLMSSampler(thread_data.model)
elif sampler_name == 'ddim':
sampler = DDIMSampler(thread_data.model)
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim, intermediates = sampler.sample(
S=opt_ddim_steps,
conditioning=c,
batch_size=opt_n_samples,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
else:
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
# encode (scaled latent)
z_enc = thread_data.model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(thread_data.device),
opt_seed,
opt_ddim_eta,
opt_ddim_steps,
)
x_T = None if mask is None else init_latent
# decode it
samples_ddim = thread_data.model.sample(
t_enc,
c,
z_enc,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
img_callback=img_callback,
mask=mask,
x_T=x_T,
sampler = 'ddim'
)
yield from samples_ddim
if thread_data.test_sd2:
from ldm.models.diffusion.ddim import DDIMSampler
sampler = DDIMSampler(thread_data.model)
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(thread_data.device))
samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt_scale,unconditional_conditioning=uc, img_callback=img_callback)
else:
z_enc = thread_data.model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(thread_data.device),
opt_seed,
opt_ddim_eta,
opt_ddim_steps,
)
# decode it
samples_ddim = thread_data.model.sample(
t_enc,
c,
z_enc,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
img_callback=img_callback,
mask=mask,
x_T=x_T,
sampler = 'ddim'
)
return samples_ddim
def gc():
gc_collect()
@ -776,8 +920,16 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False):
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img, output_format="PNG"):
buffered = img_to_buffer(img, output_format)
return buffer_to_base64_str(buffered, output_format)
def img_to_buffer(img, output_format="PNG"):
buffered = BytesIO()
img.save(buffered, format=output_format)
buffered.seek(0)
return buffered
def buffer_to_base64_str(buffered, output_format="PNG"):
buffered.seek(0)
img_byte = buffered.getvalue()
mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
@ -795,3 +947,48 @@ def base64_str_to_img(img_str):
buffered = base64_str_to_buffer(img_str)
img = Image.open(buffered)
return img
def split_weighted_subprompts(text):
"""
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
remaining = len(text)
prompts = []
weights = []
while remaining > 0:
if ":" in text:
idx = text.index(":") # first occurrence from start
# grab up to index as sub-prompt
prompt = text[:idx]
remaining -= idx
# remove from main text
text = text[idx+1:]
# find value for weight
if " " in text:
idx = text.index(" ") # first occurence
else: # no space, read to end
idx = len(text)
if idx != 0:
try:
weight = float(text[:idx])
except: # couldn't treat as float
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
weight = 1.0
else: # no value found
weight = 1.0
# remove from main text
remaining -= idx
text = text[idx+1:]
# append the sub-prompt and its weight
prompts.append(prompt)
weights.append(weight)
else: # no : found
if len(text) > 0: # there is still text though
# take remainder as weight 1
prompts.append(text)
weights.append(1.0)
remaining = 0
return prompts, weights

View File

@ -283,45 +283,26 @@ def thread_render(device):
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
if runtime.thread_data.device == 'cpu' and is_alive() > 1:
# CPU is not the only device. Keep track of active time to unload resources later.
runtime.thread_data.lastActive = time.time()
# Open data generator.
res = runtime.mk_img(task.request)
if current_model_path == task.request.use_stable_diffusion_model:
current_state = ServerStates.Rendering
else:
if runtime.is_model_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
# Start reading from generator.
dataQueue = None
if task.request.stream_progress_updates:
dataQueue = task.buffer_queue
for result in res:
if current_state == ServerStates.LoadingModel:
current_state = ServerStates.Rendering
current_model_path = task.request.use_stable_diffusion_model
current_vae_path = task.request.use_vae_model
runtime.reload_model()
current_model_path = task.request.use_stable_diffusion_model
current_vae_path = task.request.use_vae_model
def step_callback():
global current_state_error
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.thread_data.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
if dataQueue:
dataQueue.put(result)
if isinstance(result, str):
result = json.loads(result)
task.response = result
if 'output' in result:
for out_obj in result['output']:
if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj:
buf = runtime.base64_str_to_buffer(out_obj['data'])
task.temp_images[result['output'].index(out_obj)] = buf
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(task.request.session_id, TASK_TTL)
task_cache.keep(task.request.session_id, TASK_TTL)
current_state = ServerStates.Rendering
task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
except Exception as e:
task.error = e
print(traceback.format_exc())

View File

@ -116,6 +116,8 @@ def setConfig(config):
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
@ -133,6 +135,8 @@ def setConfig(config):
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
@ -140,12 +144,19 @@ def setConfig(config):
print(traceback.format_exc())
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
config = getConfig()
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
if not model_name: # When None try user configured model.
config = getConfig()
# config = getConfig()
if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type]
if model_name:
is_sd2 = config.get('test_sd2', False)
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
model_name = 'sd-v1-4'
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
for model_extension in model_extensions:
@ -188,6 +199,7 @@ class SetAppConfigRequest(BaseModel):
ui_open_browser_on_start: bool = None
listen_to_network: bool = None
listen_port: int = None
test_sd2: bool = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
@ -208,6 +220,8 @@ async def setAppConfig(req : SetAppConfigRequest):
if 'net' not in config:
config['net'] = {}
config['net']['listen_port'] = int(req.listen_port)
if req.test_sd2 is not None:
config['test_sd2'] = req.test_sd2
try:
setConfig(config)
@ -230,9 +244,9 @@ def is_malicious_model(file_path):
return False
except Exception as e:
print('error while scanning', file_path, 'error:', e)
return False
known_models = {}
def getModels():
models = {
'active': {
@ -255,9 +269,14 @@ def getModels():
if not file.endswith(model_extension):
continue
if is_malicious_model(os.path.join(models_dir, file)):
models['scan-error'] = file
return
model_path = os.path.join(models_dir, file)
mtime = os.path.getmtime(model_path)
mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
if is_malicious_model(model_path):
models['scan-error'] = file
return
known_models[model_path] = mtime
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
@ -435,6 +454,9 @@ class LogSuppressFilter(logging.Filter):
return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# Check models and prepare cache for UI open
getModels()
# Start the task_manager
task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use()