diff --git a/CHANGES.md b/CHANGES.md index 8ce22d20..e6bf4d1a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -22,7 +22,9 @@ - Ask for a confimation before clearing the results pane or stopping a render task. The dialog can be skipped by holding down the shift key while clicking on the button. ### Detailed changelog -* 2.4.14 - 22 Nov 2022 - shiftOrConfirm for red buttons +* 2.4.17 - 30 Nov 2022 - Confirm before stopping or clearing all the tasks +* 2.4.16 - 29 Nov 2022 - Bug fixes for SD 2.0 - remove the need for patching, default to SD 1.4 model if trying to load an SD2 model in SD1.4. +* 2.4.15 - 25 Nov 2022 - Experimental support for SD 2.0. Uses lots of memory, not optimized, probably GPU-only. * 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion * 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac * 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing. diff --git a/scripts/Developer Console.cmd b/scripts/Developer Console.cmd index 70e809a3..750e4311 100644 --- a/scripts/Developer Console.cmd +++ b/scripts/Developer Console.cmd @@ -29,6 +29,18 @@ call conda activate .\stable-diffusion\env call where python call python --version +@rem set the PYTHONPATH +cd stable-diffusion +set SD_DIR=%cd% + +cd env\lib\site-packages +set PYTHONPATH=%SD_DIR%;%cd% +cd ..\..\.. +echo PYTHONPATH=%PYTHONPATH% + +cd .. + +@rem done echo. cmd /k diff --git a/scripts/developer_console.sh b/scripts/developer_console.sh index 58344678..49e71b34 100755 --- a/scripts/developer_console.sh +++ b/scripts/developer_console.sh @@ -35,6 +35,15 @@ if [ "$0" == "bash" ]; then which python python --version + # set the PYTHONPATH + cd stable-diffusion + SD_PATH=`pwd` + export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages" + echo "PYTHONPATH=$PYTHONPATH" + cd .. + + # done + echo "" else file_name=$(basename "${BASH_SOURCE[0]}") diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 6e4ffd36..e088ed57 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -27,6 +27,8 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd" @call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');" +if NOT DEFINED test_sd2 set test_sd2=N + @>nul findstr /m "sd_git_cloned" scripts\install_status.txt @if "%ERRORLEVEL%" EQU "0" ( @echo "Stable Diffusion's git repository was already installed. Updating.." 
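The developer-console scripts now export PYTHONPATH so that a bare `python` started from that console can import both the bundled Stable Diffusion code and the conda environment's packages. A minimal, illustrative check of what that setup provides (directory names follow the .cmd script; the .sh variant uses `env/lib/python3.8/site-packages` instead):

```python
# Illustrative only (not part of the patch): confirm that the directories the
# developer console adds to PYTHONPATH are visible to the interpreter.
import os
import sys

sd_dir = os.path.abspath("stable-diffusion")                         # %SD_DIR% / $SD_PATH
site_packages = os.path.join(sd_dir, "env", "lib", "site-packages")  # Windows layout
resolved = [os.path.abspath(p) for p in sys.path]

for path in (sd_dir, site_packages):
    print(path, "->", "on sys.path" if path in resolved else "missing")
```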
@@ -37,9 +39,13 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd" @call git reset --hard @call git pull - @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch + if "%test_sd2%" == "N" ( + @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a + ) + if "%test_sd2%" == "Y" ( + @call git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9 + ) @cd .. ) else ( @@ -56,8 +62,6 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd" @cd stable-diffusion @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch - @cd .. ) @@ -346,7 +350,9 @@ echo. > "..\models\vae\Put your VAE files here.txt" ) ) - +if "%test_sd2%" == "Y" ( + @call pip install open_clip_torch==2.0.2 +) @>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt @if "%ERRORLEVEL%" NEQ "0" ( diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index ff6a04d4..199a9ae8 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -21,6 +21,10 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d # Caution, this file will make your eyes and brain bleed. It's such an unholy mess. # Note to self: Please rewrite this in Python. For the sake of your own sanity. +if [ "$test_sd2" == "" ]; then + export test_sd2="N" +fi + if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then echo "Stable Diffusion's git repository was already installed. Updating.." @@ -30,9 +34,12 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta git reset --hard git pull - git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed" + if [ "$test_sd2" == "N" ]; then + git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a + elif [ "$test_sd2" == "Y" ]; then + git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9 + fi cd .. else @@ -47,8 +54,6 @@ else cd stable-diffusion git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed" - cd .. fi @@ -291,6 +296,9 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then fi fi +if [ "$test_sd2" == "Y" ]; then + pip install open_clip_torch==2.0.2 +fi if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then echo sd_weights_downloaded >> ../scripts/install_status.txt diff --git a/ui/index.html b/ui/index.html index e28c2af0..153b18c4 100644 --- a/ui/index.html +++ b/ui/index.html @@ -24,7 +24,7 @@
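The installer scripts now branch on the `test_sd2` environment variable: the usual SD 1.x commit (`7f32368e…`) is checked out when the flag is `N`, the SD 2.0-capable commit (`8878d67d…`) when it is `Y`, and `open_clip_torch==2.0.2` is installed only in the SD 2.0 case. The separate `ddim_callback.patch` apply step is gone, matching the 2.4.16 changelog note that patching is no longer needed. A rough Python summary of that decision flow — the helper name is hypothetical; only the commit hashes, the env-var name, and the package pin come from the patch:

```python
# Hypothetical condensed form of the install-script branching shown above.
import os
import subprocess

SD1_COMMIT = "7f32368ed1030a6e710537047bacd908adea183a"
SD2_COMMIT = "8878d67decd3deb3c98472c1e39d2a51dc5950f9"

def sync_backend(repo_dir: str = "stable-diffusion") -> None:
    test_sd2 = os.environ.get("test_sd2", "N")
    commit = SD2_COMMIT if test_sd2 == "Y" else SD1_COMMIT
    subprocess.run(
        ["git", "-c", "advice.detachedHead=false", "checkout", commit],
        cwd=repo_dir, check=True,
    )
    if test_sd2 == "Y":
        # SD 2.0 additionally needs the OpenCLIP text encoder.
        subprocess.run(["pip", "install", "open_clip_torch==2.0.2"], check=True)
```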
diff --git a/ui/media/css/main.css b/ui/media/css/main.css index 0cf83302..b1af0b16 100644 --- a/ui/media/css/main.css +++ b/ui/media/css/main.css @@ -210,7 +210,7 @@ code { } .collapsible-content { display: block; - padding-left: 15px; + padding-left: 10px; } .collapsible-content h5 { padding: 5pt 0pt; @@ -658,11 +658,15 @@ input::file-selector-button { opacity: 1; } -/* MOBILE SUPPORT */ -@media screen and (max-width: 700px) { +/* Small screens */ +@media screen and (max-width: 1265px) { #top-nav { flex-direction: column; } +} + +/* MOBILE SUPPORT */ +@media screen and (max-width: 700px) { body { margin: 0px; } @@ -712,7 +716,7 @@ input::file-selector-button { padding-right: 0px; } #server-status { - display: none; + top: 75%; } .popup > div { padding-left: 5px !important; @@ -730,6 +734,15 @@ input::file-selector-button { } } +@media screen and (max-width: 500px) { + #server-status #server-status-msg { + display: none; + } + #server-status:hover #server-status-msg { + display: inline; + } +} + @media (min-width: 700px) { /* #editor { max-width: 480px; diff --git a/ui/media/js/image-modifiers.js b/ui/media/js/image-modifiers.js index 24347fc4..7ba967b4 100644 --- a/ui/media/js/image-modifiers.js +++ b/ui/media/js/image-modifiers.js @@ -90,9 +90,7 @@ function createModifierGroup(modifierGroup, initiallyExpanded) { if (activeTags.map(x => x.name).includes(modifierName)) { // remove modifier from active array activeTags = activeTags.filter(x => x.name != modifierName) - modifierCard.classList.remove(activeCardClass) - - modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+' + toggleCardState(modifierCard, false) } else { // add modifier to active array activeTags.push({ @@ -101,10 +99,7 @@ function createModifierGroup(modifierGroup, initiallyExpanded) { 'originElement': modifierCard, 'previews': modifierPreviews }) - - modifierCard.classList.add(activeCardClass) - - modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-' + toggleCardState(modifierCard, true) } refreshTagsList() @@ -222,8 +217,7 @@ function refreshTagsList() { let idx = activeTags.indexOf(tag) if (idx !== -1 && activeTags[idx].originElement !== undefined) { - activeTags[idx].originElement.classList.remove(activeCardClass) - activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+' + toggleCardState(activeTags[idx].originElement, false) activeTags.splice(idx, 1) refreshTagsList() @@ -236,6 +230,16 @@ function refreshTagsList() { editorModifierTagsList.appendChild(brk) } +function toggleCardState(card, makeActive) { + if (makeActive) { + card.classList.add(activeCardClass) + card.querySelector('.modifier-card-image-overlay').innerText = '-' + } else { + card.classList.remove(activeCardClass) + card.querySelector('.modifier-card-image-overlay').innerText = '+' + } +} + function changePreviewImages(val) { const previewImages = document.querySelectorAll('.modifier-card-image-container img') diff --git a/ui/media/js/main.js b/ui/media/js/main.js index e46c530d..a971e02b 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -960,10 +960,10 @@ function getPrompts() { prompts = prompts.filter(prompt => prompt !== '') if (activeTags.length > 0) { - const promptTags = activeTags.map(x => x.name).join(", ") - prompts = prompts.map((prompt) => `${prompt}, ${promptTags}`) + const promptTags = activeTags.map(x => x.name).join(", ") + prompts = prompts.map((prompt) => `${prompt}, ${promptTags}`) } - + let promptsToMake = applySetOperator(prompts) promptsToMake 
= applyPermuteOperator(promptsToMake) diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index dfa33246..7eaee64e 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -5,9 +5,9 @@ */ var ParameterType = { checkbox: "checkbox", - select: "select", - select_multiple: "select_multiple", - custom: "custom", + select: "select", + select_multiple: "select_multiple", + custom: "custom", }; /** @@ -23,174 +23,182 @@ /** @type {Array.} */ var PARAMETERS = [ - { - id: "theme", - type: ParameterType.select, - label: "Theme", - default: "theme-default", - note: "customize the look and feel of the ui", - options: [ // Note: options expanded dynamically - { - value: "theme-default", - label: "Default" - } - ], - icon: "fa-palette" - }, - { - id: "save_to_disk", - type: ParameterType.checkbox, - label: "Auto-Save Images", - note: "automatically saves images to the specified location", - icon: "fa-download", - default: false, - }, - { - id: "diskPath", - type: ParameterType.custom, - label: "Save Location", - render: (parameter) => { - return `` - } - }, - { - id: "sound_toggle", - type: ParameterType.checkbox, - label: "Enable Sound", - note: "plays a sound on task completion", - icon: "fa-volume-low", - default: true, - }, - { - id: "ui_open_browser_on_start", - type: ParameterType.checkbox, - label: "Open browser on startup", - note: "starts the default browser on startup", - icon: "fa-window-restore", - default: true, - }, - { - id: "turbo", - type: ParameterType.checkbox, - label: "Turbo Mode", - note: "generates images faster, but uses an additional 1 GB of GPU memory", - icon: "fa-forward", - default: true, - }, - { - id: "use_cpu", - type: ParameterType.checkbox, - label: "Use CPU (not GPU)", - note: "warning: this will be *very* slow", - icon: "fa-microchip", - default: false, - }, - { - id: "auto_pick_gpus", - type: ParameterType.checkbox, - label: "Automatically pick the GPUs (experimental)", - default: false, - }, - { - id: "use_gpus", - type: ParameterType.select_multiple, - label: "GPUs to use (experimental)", - note: "to process in parallel", - default: false, - }, - { - id: "use_full_precision", - type: ParameterType.checkbox, - label: "Use Full Precision", - note: "for GPU-only. warning: this will consume more VRAM", - icon: "fa-crosshairs", - default: false, - }, - { - id: "auto_save_settings", - type: ParameterType.checkbox, - label: "Auto-Save Settings", - note: "restores settings on browser load", - icon: "fa-gear", - default: true, - }, - { - id: "confirm_dangerous_actions", - type: ParameterType.checkbox, - label: "Confirm dangerous actions", - note: "Actions that might lead to data loss must either be clicked with the shift key pressed, or confirmed in an 'are you sure?' dialog", - icon: "fa-check-double", - default: true, - }, - { - id: "listen_to_network", - type: ParameterType.checkbox, - label: "Make Stable Diffusion available on your network", - note: "Other devices on your network can access this web page", - icon: "fa-network-wired", - default: true, - }, - { - id: "listen_port", - type: ParameterType.custom, - label: "Network port", - note: "Port that this server listens to. The '9000' part in 'http://localhost:9000'", - icon: "fa-anchor", - render: (parameter) => { - return `` - } - }, - { - id: "use_beta_channel", - type: ParameterType.checkbox, - label: "Beta channel", - note: "Get the latest features immediately (but could be less stable). 
Please restart the program after changing this.", - icon: "fa-fire", - default: false, - }, + { + id: "theme", + type: ParameterType.select, + label: "Theme", + default: "theme-default", + note: "customize the look and feel of the ui", + options: [ // Note: options expanded dynamically + { + value: "theme-default", + label: "Default" + } + ], + icon: "fa-palette" + }, + { + id: "save_to_disk", + type: ParameterType.checkbox, + label: "Auto-Save Images", + note: "automatically saves images to the specified location", + icon: "fa-download", + default: false, + }, + { + id: "diskPath", + type: ParameterType.custom, + label: "Save Location", + render: (parameter) => { + return `` + } + }, + { + id: "sound_toggle", + type: ParameterType.checkbox, + label: "Enable Sound", + note: "plays a sound on task completion", + icon: "fa-volume-low", + default: true, + }, + { + id: "ui_open_browser_on_start", + type: ParameterType.checkbox, + label: "Open browser on startup", + note: "starts the default browser on startup", + icon: "fa-window-restore", + default: true, + }, + { + id: "turbo", + type: ParameterType.checkbox, + label: "Turbo Mode", + note: "generates images faster, but uses an additional 1 GB of GPU memory", + icon: "fa-forward", + default: true, + }, + { + id: "use_cpu", + type: ParameterType.checkbox, + label: "Use CPU (not GPU)", + note: "warning: this will be *very* slow", + icon: "fa-microchip", + default: false, + }, + { + id: "auto_pick_gpus", + type: ParameterType.checkbox, + label: "Automatically pick the GPUs (experimental)", + default: false, + }, + { + id: "use_gpus", + type: ParameterType.select_multiple, + label: "GPUs to use (experimental)", + note: "to process in parallel", + default: false, + }, + { + id: "use_full_precision", + type: ParameterType.checkbox, + label: "Use Full Precision", + note: "for GPU-only. warning: this will consume more VRAM", + icon: "fa-crosshairs", + default: false, + }, + { + id: "auto_save_settings", + type: ParameterType.checkbox, + label: "Auto-Save Settings", + note: "restores settings on browser load", + icon: "fa-gear", + default: true, + }, + { + id: "confirm_dangerous_actions", + type: ParameterType.checkbox, + label: "Confirm dangerous actions", + note: "Actions that might lead to data loss must either be clicked with the shift key pressed, or confirmed in an 'Are you sure?' dialog", + icon: "fa-check-double", + default: true, + }, + { + id: "listen_to_network", + type: ParameterType.checkbox, + label: "Make Stable Diffusion available on your network", + note: "Other devices on your network can access this web page", + icon: "fa-network-wired", + default: true, + }, + { + id: "listen_port", + type: ParameterType.custom, + label: "Network port", + note: "Port that this server listens to. The '9000' part in 'http://localhost:9000'", + icon: "fa-anchor", + render: (parameter) => { + return `` + } + }, + { + id: "test_sd2", + type: ParameterType.checkbox, + label: "Test SD 2.0", + note: "Experimental! High memory usage! GPU-only! Not the final version! Please restart the program after changing this.", + icon: "fa-fire", + default: false, + }, + { + id: "use_beta_channel", + type: ParameterType.checkbox, + label: "Beta channel", + note: "Get the latest features immediately (but could be less stable). 
Please restart the program after changing this.", + icon: "fa-fire", + default: false, + }, ]; function getParameterSettingsEntry(id) { - let parameter = PARAMETERS.filter(p => p.id === id) - if (parameter.length === 0) { - return - } - return parameter[0].settingsEntry + let parameter = PARAMETERS.filter(p => p.id === id) + if (parameter.length === 0) { + return + } + return parameter[0].settingsEntry } function getParameterElement(parameter) { - switch (parameter.type) { - case ParameterType.checkbox: - var is_checked = parameter.default ? " checked" : ""; - return `` - case ParameterType.select: - case ParameterType.select_multiple: - var options = (parameter.options || []).map(option => ``).join("") - var multiple = (parameter.type == ParameterType.select_multiple ? 'multiple' : '') - return `` - case ParameterType.custom: - return parameter.render(parameter) - default: - console.error(`Invalid type for parameter ${parameter.id}`); - return "ERROR: Invalid Type" - } + switch (parameter.type) { + case ParameterType.checkbox: + var is_checked = parameter.default ? " checked" : ""; + return `` + case ParameterType.select: + case ParameterType.select_multiple: + var options = (parameter.options || []).map(option => ``).join("") + var multiple = (parameter.type == ParameterType.select_multiple ? 'multiple' : '') + return `` + case ParameterType.custom: + return parameter.render(parameter) + default: + console.error(`Invalid type for parameter ${parameter.id}`); + return "ERROR: Invalid Type" + } } let parametersTable = document.querySelector("#system-settings .parameters-table") /* fill in the system settings popup table */ function initParameters() { - PARAMETERS.forEach(parameter => { - var element = getParameterElement(parameter) - var note = parameter.note ? `${parameter.note}` : ""; - var icon = parameter.icon ? `` : ""; - var newrow = document.createElement('div') - newrow.innerHTML = ` -
${icon}
-
${note}
-
${element}
` - parametersTable.appendChild(newrow) - parameter.settingsEntry = newrow - }) + PARAMETERS.forEach(parameter => { + var element = getParameterElement(parameter) + var note = parameter.note ? `${parameter.note}` : ""; + var icon = parameter.icon ? `` : ""; + var newrow = document.createElement('div') + newrow.innerHTML = ` +
${icon}
+
${note}
+
${element}
` + parametersTable.appendChild(newrow) + parameter.settingsEntry = newrow + }) } initParameters() @@ -204,6 +212,7 @@ let saveToDiskField = document.querySelector('#save_to_disk') let diskPathField = document.querySelector('#diskPath') let listenToNetworkField = document.querySelector("#listen_to_network") let listenPortField = document.querySelector("#listen_port") +let testSD2Field = document.querySelector("#test_sd2") let useBetaChannelField = document.querySelector("#use_beta_channel") let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start") let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions") @@ -239,12 +248,18 @@ async function getAppConfig() { if (config.ui && config.ui.open_browser_on_start === false) { uiOpenBrowserOnStartField.checked = false } - if (config.net && config.net.listen_to_network === false) { - listenToNetworkField.checked = false - } - if (config.net && config.net.listen_port !== undefined) { - listenPortField.value = config.net.listen_port - } + if ('test_sd2' in config) { + testSD2Field.checked = config['test_sd2'] + } + + let testSD2SettingEntry = getParameterSettingsEntry('test_sd2') + testSD2SettingEntry.style.display = (config.update_branch === 'beta' ? '' : 'none') + if (config.net && config.net.listen_to_network === false) { + listenToNetworkField.checked = false + } + if (config.net && config.net.listen_port !== undefined) { + listenPortField.value = config.net.listen_port + } console.log('get config status response', config) } catch (e) { @@ -272,7 +287,6 @@ function getCurrentRenderDeviceSelection() { useCPUField.addEventListener('click', function() { let gpuSettingEntry = getParameterSettingsEntry('use_gpus') let autoPickGPUSettingEntry = getParameterSettingsEntry('auto_pick_gpus') - console.log("hello", this.checked); if (this.checked) { gpuSettingEntry.style.display = 'none' autoPickGPUSettingEntry.style.display = 'none' @@ -369,22 +383,23 @@ async function getDevices() { } saveSettingsBtn.addEventListener('click', function() { - let updateBranch = (useBetaChannelField.checked ? 'beta' : 'main') + let updateBranch = (useBetaChannelField.checked ? 
'beta' : 'main') - if (listenPortField.value == '') { - alert('The network port field must not be empty.') - } else if (listenPortField.value<1 || listenPortField.value>65535) { - alert('The network port must be a number from 1 to 65535') - } else { - changeAppConfig({ - 'render_devices': getCurrentRenderDeviceSelection(), - 'update_branch': updateBranch, - 'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked, - 'listen_to_network': listenToNetworkField.checked, - 'listen_port': listenPortField.value - }) - } + if (listenPortField.value == '') { + alert('The network port field must not be empty.') + } else if (listenPortField.value<1 || listenPortField.value>65535) { + alert('The network port must be a number from 1 to 65535') + } else { + changeAppConfig({ + 'render_devices': getCurrentRenderDeviceSelection(), + 'update_branch': updateBranch, + 'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked, + 'listen_to_network': listenToNetworkField.checked, + 'listen_port': listenPortField.value, + 'test_sd2': testSD2Field.checked + }) + } - saveSettingsBtn.classList.add('active') - asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active')) + saveSettingsBtn.classList.add('active') + asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active')) }) diff --git a/ui/media/js/utils.js b/ui/media/js/utils.js index 6fc3c402..a76f030e 100644 --- a/ui/media/js/utils.js +++ b/ui/media/js/utils.js @@ -1,17 +1,17 @@ // https://gomakethings.com/finding-the-next-and-previous-sibling-elements-that-match-a-selector-with-vanilla-js/ function getNextSibling(elem, selector) { - // Get the next sibling element - var sibling = elem.nextElementSibling + // Get the next sibling element + var sibling = elem.nextElementSibling - // If there's no selector, return the first sibling - if (!selector) return sibling + // If there's no selector, return the first sibling + if (!selector) return sibling - // If the sibling matches our selector, use it - // If not, jump to the next sibling and continue the loop - while (sibling) { - if (sibling.matches(selector)) return sibling - sibling = sibling.nextElementSibling - } + // If the sibling matches our selector, use it + // If not, jump to the next sibling and continue the loop + while (sibling) { + if (sibling.matches(selector)) return sibling + sibling = sibling.nextElementSibling + } } diff --git a/ui/sd_internal/ddim_callback_sd2.patch b/ui/sd_internal/ddim_callback_sd2.patch new file mode 100644 index 00000000..cadf81ca --- /dev/null +++ b/ui/sd_internal/ddim_callback_sd2.patch @@ -0,0 +1,84 @@ +diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py +index 27ead0e..6215939 100644 +--- a/ldm/models/diffusion/ddim.py ++++ b/ldm/models/diffusion/ddim.py +@@ -100,7 +100,7 @@ class DDIMSampler(object): + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + +- samples, intermediates = self.ddim_sampling(conditioning, size, ++ samples = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, +@@ -117,7 +117,8 @@ class DDIMSampler(object): + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule + ) +- return samples, intermediates ++ # return samples, intermediates ++ yield from samples + + @torch.no_grad() + def ddim_sampling(self, cond, shape, +@@ -168,14 +169,15 @@ class DDIMSampler(object): + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs 
+- if callback: callback(i) +- if img_callback: img_callback(pred_x0, i) ++ if callback: yield from callback(i) ++ if img_callback: yield from img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + +- return img, intermediates ++ # return img, intermediates ++ yield from img_callback(pred_x0, len(iterator)-1) + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, +diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py +index 7002a36..0951f39 100644 +--- a/ldm/models/diffusion/plms.py ++++ b/ldm/models/diffusion/plms.py +@@ -96,7 +96,7 @@ class PLMSSampler(object): + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + +- samples, intermediates = self.plms_sampling(conditioning, size, ++ samples = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, +@@ -112,7 +112,8 @@ class PLMSSampler(object): + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) +- return samples, intermediates ++ #return samples, intermediates ++ yield from samples + + @torch.no_grad() + def plms_sampling(self, cond, shape, +@@ -165,14 +166,15 @@ class PLMSSampler(object): + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) +- if callback: callback(i) +- if img_callback: img_callback(pred_x0, i) ++ if callback: yield from callback(i) ++ if img_callback: yield from img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + +- return img, intermediates ++ # return img, intermediates ++ yield from img_callback(pred_x0, len(iterator)-1) + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index e217965d..26c116ad 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -7,6 +7,7 @@ Notes: import json import os, re import traceback +import queue import torch import numpy as np from gc import collect as gc_collect @@ -21,13 +22,14 @@ from torch import autocast from contextlib import nullcontext from einops import rearrange, repeat from ldm.util import instantiate_from_config -from optimizedSD.optimUtils import split_weighted_subprompts from transformers import logging from gfpgan import GFPGANer from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer +from threading import Lock + import uuid logging.set_verbosity_error() @@ -35,7 +37,7 @@ logging.set_verbosity_error() # consts config_yaml = "optimizedSD/v1-inference.yaml" filename_regex = re.compile('[^a-zA-Z0-9]') -force_gfpgan_to_cuda0 = True # workaround: gfpgan currently works only on cuda:0 +gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time. 
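`ddim_callback_sd2.patch` (above) rewrites the stock DDIM and PLMS samplers so that the sampling loop no longer returns `(samples, intermediates)` but instead `yield`s whatever the image callback yields at each step; that is what lets the UI stream partial previews and stop a run mid-way. A stripped-down, torch-free sketch of that generator-callback pattern, purely to show the control flow:

```python
# Toy version of the pattern the patch introduces: the sampler becomes a
# generator that forwards whatever the per-step callback yields.
def sampling_loop(total_steps, img_callback=None):
    latent = 0.0                                  # stand-in for the latent tensor
    for i in range(total_steps):
        latent += 1.0                             # placeholder for one denoising step
        if img_callback:
            yield from img_callback(latent, i)    # was: img_callback(pred_x0, i)

def preview_callback(latent, step):
    # In runtime.py this is where partial images are pushed to the UI.
    yield {"step": step, "latent": latent}

for update in sampling_loop(3, preview_callback):
    print(update)    # {'step': 0, ...}, {'step': 1, ...}, {'step': 2, ...}
```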
# api stuff from sd_internal import device_manager @@ -76,8 +78,24 @@ def thread_init(device): thread_data.force_full_precision = False thread_data.reduced_memory = True + thread_data.test_sd2 = isSD2() + device_manager.device_init(thread_data, device) +# temp hack, will remove soon +def isSD2(): + try: + SD_UI_DIR = os.getenv('SD_UI_PATH', None) + CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) + config_json_path = os.path.join(CONFIG_DIR, 'config.json') + if not os.path.exists(config_json_path): + return False + with open(config_json_path, 'r', encoding='utf-8') as f: + config = json.load(f) + return config.get('test_sd2', False) + except Exception as e: + return False + def load_model_ckpt(): if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.') if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt') @@ -92,6 +110,13 @@ def load_model_ckpt(): thread_data.precision = 'full' print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision) + + if thread_data.test_sd2: + load_model_ckpt_sd2() + else: + load_model_ckpt_sd1() + +def load_model_ckpt_sd1(): sd = load_model_from_config(thread_data.ckpt_file + '.ckpt') li, lo = [], [] for key, value in sd.items(): @@ -185,6 +210,38 @@ def load_model_ckpt(): modelFS.device: {thread_data.modelFS.device} using precision: {thread_data.precision}''') +def load_model_ckpt_sd2(): + config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in thread_data.ckpt_file else "configs/stable-diffusion/v1-inference.yaml" + config = OmegaConf.load(config_file) + verbose = False + + sd = load_model_from_config(thread_data.ckpt_file + '.ckpt') + + thread_data.model = instantiate_from_config(config.model) + m, u = thread_data.model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + thread_data.model.to(thread_data.device) + thread_data.model.eval() + del sd + + if thread_data.device != "cpu" and thread_data.precision == "autocast": + thread_data.model.half() + thread_data.model_is_half = True + thread_data.model_fs_is_half = True + else: + thread_data.model_is_half = False + thread_data.model_fs_is_half = False + + print(f'''loaded model + model file: {thread_data.ckpt_file}.ckpt + using precision: {thread_data.precision}''') + def unload_filters(): if thread_data.model_gfpgan is not None: if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu') @@ -204,10 +261,11 @@ def unload_models(): if thread_data.model is not None: print('Unloading models...') if thread_data.device != 'cpu': - thread_data.modelFS.to('cpu') - thread_data.modelCS.to('cpu') - thread_data.model.model1.to("cpu") - thread_data.model.model2.to("cpu") + if not thread_data.test_sd2: + thread_data.modelFS.to('cpu') + thread_data.modelCS.to('cpu') + thread_data.model.model1.to("cpu") + thread_data.model.model2.to("cpu") del thread_data.model del thread_data.modelCS @@ -253,12 +311,6 @@ def move_to_cpu(model): def load_model_gfpgan(): if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.') - - # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files - from facexlib.detection import retinaface - retinaface.device = torch.device(thread_data.device) - print('forced retinaface.device to', thread_data.device) - model_path = 
thread_data.gfpgan_file + ".pth" thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision) @@ -314,15 +366,23 @@ def apply_filters(filter_name, image_data, model_path=None): image_data.to(thread_data.device) if filter_name == 'gfpgan': - if model_path is not None and model_path != thread_data.gfpgan_file: - thread_data.gfpgan_file = model_path - load_model_gfpgan() - elif not thread_data.model_gfpgan: - load_model_gfpgan() - if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.') - print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision) - _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - image_data = output[:,:,::-1] + # This lock is only ever used here. No need to use timeout for the request. Should never deadlock. + with gfpgan_temp_device_lock: # Wait for any other devices to complete before starting. + # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files + from facexlib.detection import retinaface + retinaface.device = torch.device(thread_data.device) + print('forced retinaface.device to', thread_data.device) + + if model_path is not None and model_path != thread_data.gfpgan_file: + thread_data.gfpgan_file = model_path + load_model_gfpgan() + elif not thread_data.model_gfpgan: + load_model_gfpgan() + if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.') + + print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision) + _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + image_data = output[:,:,::-1] if filter_name == 'real_esrgan': if model_path is not None and model_path != thread_data.real_esrgan_file: @@ -337,45 +397,73 @@ def apply_filters(filter_name, image_data, model_path=None): return image_data -def mk_img(req: Request): +def is_model_reload_necessary(req: Request): + # custom model support: + # the req.use_stable_diffusion_model needs to be a valid path + # to the ckpt file (without the extension). 
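The old `force_gfpgan_to_cuda0` workaround is replaced by a module-level `threading.Lock`: facexlib keeps the face-detector device in module-global state, so `apply_filters()` now takes the lock, points `retinaface.device` at the current thread's device, loads and runs GFPGAN, and only then lets the next device proceed. A minimal sketch of that serialisation idea — `load_face_fixer()` / `enhance()` are hypothetical stand-ins for the guarded GFPGAN calls:

```python
# Hedged sketch of serialising GFPGAN start-up per device, as apply_filters() now does.
import threading

gfpgan_temp_device_lock = threading.Lock()
_detector_device = None                    # mimics facexlib's process-wide retinaface.device

def enhance_on(device: str, image):
    global _detector_device
    with gfpgan_temp_device_lock:          # only one device may (re)configure at a time
        _detector_device = device
        # load_face_fixer(device)          # hypothetical: build GFPGANer for this device
        # return enhance(image)            # hypothetical: run the enhancement
        return image                       # placeholder so the sketch runs as-is
```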
+ if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt') + + needs_model_reload = False + if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model: + thread_data.ckpt_file = req.use_stable_diffusion_model + thread_data.vae_file = req.use_vae_model + needs_model_reload = True + + if thread_data.device != 'cpu': + if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \ + (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision): + thread_data.precision = 'full' if req.use_full_precision else 'autocast' + needs_model_reload = True + + return needs_model_reload + +def reload_model(): + unload_models() + unload_filters() + load_model_ckpt() + +def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): try: - yield from do_mk_img(req) + return do_mk_img(req, data_queue, task_temp_images, step_callback) except Exception as e: print(traceback.format_exc()) - if thread_data.device != 'cpu': + if thread_data.device != 'cpu' and not thread_data.test_sd2: thread_data.modelFS.to('cpu') thread_data.modelCS.to('cpu') thread_data.model.model1.to("cpu") thread_data.model.model2.to("cpu") gc() # Release from memory. - yield json.dumps({ + data_queue.put(json.dumps({ "status": 'failed', "detail": str(e) - }) + })) + raise e -def update_temp_img(req, x_samples): +def update_temp_img(req, x_samples, task_temp_images: list): partial_images = [] for i in range(req.num_outputs): - x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) + if thread_data.test_sd2: + x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0)) + else: + x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") x_sample = x_sample.astype(np.uint8) img = Image.fromarray(x_sample) - buf = BytesIO() - img.save(buf, format='JPEG') - buf.seek(0) + buf = img_to_buffer(img, output_format='JPEG') del img, x_sample, x_sample_ddim # don't delete x_samples, it is used in the code that called this callback thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf + task_temp_images[i] = buf partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'}) return partial_images # Build and return the apropriate generator for do_mk_img -def get_image_progress_generator(req, extra_props=None): +def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None): if not req.stream_progress_updates: def empty_callback(x_samples, i): return x_samples return empty_callback @@ -394,15 +482,17 @@ def get_image_progress_generator(req, extra_props=None): progress.update(extra_props) if req.stream_image_progress and i % 5 == 0: - progress['output'] = update_temp_img(req, x_samples) + progress['output'] = update_temp_img(req, x_samples, task_temp_images) - yield json.dumps(progress) + data_queue.put(json.dumps(progress)) + + step_callback() if thread_data.stop_processing: raise UserInitiatedStop("User requested that we stop processing") return img_callback -def do_mk_img(req: Request): +def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): 
thread_data.stop_processing = False res = Response() @@ -411,29 +501,7 @@ def do_mk_img(req: Request): thread_data.temp_images.clear() - # custom model support: - # the req.use_stable_diffusion_model needs to be a valid path - # to the ckpt file (without the extension). - if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt') - - needs_model_reload = False - if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model: - thread_data.ckpt_file = req.use_stable_diffusion_model - thread_data.vae_file = req.use_vae_model - needs_model_reload = True - - if thread_data.device != 'cpu': - if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \ - (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision): - thread_data.precision = 'full' if req.use_full_precision else 'autocast' - needs_model_reload = True - - if needs_model_reload: - unload_models() - unload_filters() - load_model_ckpt() - - if thread_data.turbo != req.turbo: + if thread_data.turbo != req.turbo and not thread_data.test_sd2: thread_data.turbo = req.turbo thread_data.model.turbo = req.turbo @@ -478,10 +546,14 @@ def do_mk_img(req: Request): if thread_data.device != "cpu" and thread_data.precision == "autocast": init_image = init_image.half() - thread_data.modelFS.to(thread_data.device) + if not thread_data.test_sd2: + thread_data.modelFS.to(thread_data.device) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) - init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space + if thread_data.test_sd2: + init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space + else: + init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space if req.mask is not None: mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device) @@ -493,7 +565,8 @@ def do_mk_img(req: Request): # Send to CPU and wait until complete. # wait_model_move_to(thread_data.modelFS, 'cpu') - move_to_cpu(thread_data.modelFS) + if not thread_data.test_sd2: + move_to_cpu(thread_data.modelFS) assert 0. 
<= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]' t_enc = int(req.prompt_strength * req.num_inference_steps) @@ -509,11 +582,14 @@ def do_mk_img(req: Request): for prompts in tqdm(data, desc="data"): with precision_scope("cuda"): - if thread_data.reduced_memory: + if thread_data.reduced_memory and not thread_data.test_sd2: thread_data.modelCS.to(thread_data.device) uc = None if req.guidance_scale != 1.0: - uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt]) + if thread_data.test_sd2: + uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt]) + else: + uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt]) if isinstance(prompts, tuple): prompts = list(prompts) @@ -526,15 +602,21 @@ def do_mk_img(req: Request): weight = weights[i] # if not skip_normalize: weight = weight / totalWeight - c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) + if thread_data.test_sd2: + c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight) + else: + c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) else: - c = thread_data.modelCS.get_learned_conditioning(prompts) + if thread_data.test_sd2: + c = thread_data.model.get_learned_conditioning(prompts) + else: + c = thread_data.modelCS.get_learned_conditioning(prompts) - if thread_data.reduced_memory: + if thread_data.reduced_memory and not thread_data.test_sd2: thread_data.modelFS.to(thread_data.device) n_steps = req.num_inference_steps if req.init_image is None else t_enc - img_callback = get_image_progress_generator(req, {"total_steps": n_steps}) + img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps}) # run the handler try: @@ -542,14 +624,7 @@ def do_mk_img(req: Request): if handler == _txt2img: x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler) else: - x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask) - - if req.stream_progress_updates: - yield from x_samples - if hasattr(thread_data, 'partial_x_samples'): - if thread_data.partial_x_samples is not None: - x_samples = thread_data.partial_x_samples - del thread_data.partial_x_samples + x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f) except UserInitiatedStop: if not hasattr(thread_data, 'partial_x_samples'): continue @@ -562,7 +637,10 @@ def do_mk_img(req: Request): print("decoding images") img_data = [None] * batch_size for i in range(batch_size): - x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) + if thread_data.test_sd2: + x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0)) + else: + x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") x_sample = x_sample.astype(np.uint8) @@ -591,9 +669,11 @@ def do_mk_img(req: Request): save_metadata(meta_out_path, req, prompts[0], opt_seed) if return_orig_img: - img_str = 
img_to_base64_str(img, req.output_format) + img_buffer = img_to_buffer(img, req.output_format) + img_str = buffer_to_base64_str(img_buffer, req.output_format) res_image_orig = ResponseImage(data=img_str, seed=opt_seed) res.images.append(res_image_orig) + task_temp_images[i] = img_buffer if req.save_to_disk_path is not None: res_image_orig.path_abs = img_out_path @@ -609,9 +689,11 @@ def do_mk_img(req: Request): filters_applied.append(req.use_upscale) if (len(filters_applied) > 0): filtered_image = Image.fromarray(img_data[i]) - filtered_img_data = img_to_base64_str(filtered_image, req.output_format) + filtered_buffer = img_to_buffer(filtered_image, req.output_format) + filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format) response_image = ResponseImage(data=filtered_img_data, seed=opt_seed) res.images.append(response_image) + task_temp_images[i] = filtered_buffer if req.save_to_disk_path is not None: filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied)) save_image(filtered_image, filtered_img_out_path) @@ -622,14 +704,18 @@ def do_mk_img(req: Request): # if thread_data.reduced_memory: # unload_filters() - move_to_cpu(thread_data.modelFS) + if not thread_data.test_sd2: + move_to_cpu(thread_data.modelFS) del img_data gc() if thread_data.device != 'cpu': print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb') print('Task completed') - yield json.dumps(res.json()) + res = res.json() + data_queue.put(json.dumps(res)) + + return res def save_image(img, img_out_path): try: @@ -664,51 +750,109 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, # Send to CPU and wait until complete. 
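`mk_img()` no longer yields JSON strings: it now receives the task's `queue.Queue`, the `task_temp_images` list, and a `step_callback`, pushes progress and result JSON straight onto the queue, and returns the final response (the `task_manager.py` changes further below call it exactly once per task). A small producer/consumer sketch of that contract — `fake_mk_img` is a stand-in; only the queue-of-JSON-strings idea comes from the patch:

```python
# Illustrative shape of the new mk_img contract.
import json
import queue

def fake_mk_img(data_queue: queue.Queue, task_temp_images: list, step_callback):
    for step in range(3):
        step_callback()                              # task manager checks for cancellation here
        data_queue.put(json.dumps({"step": step, "total_steps": 3}))
    data_queue.put(json.dumps({"status": "succeeded"}))
    return {"status": "succeeded"}

buffer_queue = queue.Queue()
result = fake_mk_img(buffer_queue, task_temp_images=[None], step_callback=lambda: None)

while not buffer_queue.empty():                      # reader side streams each JSON chunk
    print(buffer_queue.get())
print("final:", result)
```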
# wait_model_move_to(thread_data.modelCS, 'cpu') - move_to_cpu(thread_data.modelCS) + if not thread_data.test_sd2: + move_to_cpu(thread_data.modelCS) - if sampler_name == 'ddim': - thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) + if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'): + raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0') - samples_ddim = thread_data.model.sample( - S=opt_ddim_steps, - conditioning=c, - seed=opt_seed, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt_scale, - unconditional_conditioning=uc, - eta=opt_ddim_eta, - x_T=start_code, - img_callback=img_callback, - mask=mask, - sampler = sampler_name, - ) - yield from samples_ddim -def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask): + # samples, _ = sampler.sample(S=opt.steps, + # conditioning=c, + # batch_size=opt.n_samples, + # shape=shape, + # verbose=False, + # unconditional_guidance_scale=opt.scale, + # unconditional_conditioning=uc, + # eta=opt.ddim_eta, + # x_T=start_code) + + if thread_data.test_sd2: + from ldm.models.diffusion.ddim import DDIMSampler + from ldm.models.diffusion.plms import PLMSSampler + + shape = [opt_C, opt_H // opt_f, opt_W // opt_f] + + if sampler_name == 'plms': + sampler = PLMSSampler(thread_data.model) + elif sampler_name == 'ddim': + sampler = DDIMSampler(thread_data.model) + + sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) + + + samples_ddim, intermediates = sampler.sample( + S=opt_ddim_steps, + conditioning=c, + batch_size=opt_n_samples, + seed=opt_seed, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt_scale, + unconditional_conditioning=uc, + eta=opt_ddim_eta, + x_T=start_code, + img_callback=img_callback, + mask=mask, + sampler = sampler_name, + ) + else: + if sampler_name == 'ddim': + thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) + + samples_ddim = thread_data.model.sample( + S=opt_ddim_steps, + conditioning=c, + seed=opt_seed, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt_scale, + unconditional_conditioning=uc, + eta=opt_ddim_eta, + x_T=start_code, + img_callback=img_callback, + mask=mask, + sampler = sampler_name, + ) + return samples_ddim + +def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1): # encode (scaled latent) - z_enc = thread_data.model.stochastic_encode( - init_latent, - torch.tensor([t_enc] * batch_size).to(thread_data.device), - opt_seed, - opt_ddim_eta, - opt_ddim_steps, - ) x_T = None if mask is None else init_latent - # decode it - samples_ddim = thread_data.model.sample( - t_enc, - c, - z_enc, - unconditional_guidance_scale=opt_scale, - unconditional_conditioning=uc, - img_callback=img_callback, - mask=mask, - x_T=x_T, - sampler = 'ddim' - ) - yield from samples_ddim + if thread_data.test_sd2: + from ldm.models.diffusion.ddim import DDIMSampler + + sampler = DDIMSampler(thread_data.model) + + sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) + + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(thread_data.device)) + + samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt_scale,unconditional_conditioning=uc, img_callback=img_callback) + + 
else: + z_enc = thread_data.model.stochastic_encode( + init_latent, + torch.tensor([t_enc] * batch_size).to(thread_data.device), + opt_seed, + opt_ddim_eta, + opt_ddim_steps, + ) + + # decode it + samples_ddim = thread_data.model.sample( + t_enc, + c, + z_enc, + unconditional_guidance_scale=opt_scale, + unconditional_conditioning=uc, + img_callback=img_callback, + mask=mask, + x_T=x_T, + sampler = 'ddim' + ) + return samples_ddim def gc(): gc_collect() @@ -776,8 +920,16 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False): # https://stackoverflow.com/a/61114178 def img_to_base64_str(img, output_format="PNG"): + buffered = img_to_buffer(img, output_format) + return buffer_to_base64_str(buffered, output_format) + +def img_to_buffer(img, output_format="PNG"): buffered = BytesIO() img.save(buffered, format=output_format) + buffered.seek(0) + return buffered + +def buffer_to_base64_str(buffered, output_format="PNG"): buffered.seek(0) img_byte = buffered.getvalue() mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg" @@ -795,3 +947,48 @@ def base64_str_to_img(img_str): buffered = base64_str_to_buffer(img_str) img = Image.open(buffered) return img + +def split_weighted_subprompts(text): + """ + grabs all text up to the first occurrence of ':' + uses the grabbed text as a sub-prompt, and takes the value following ':' as weight + if ':' has no value defined, defaults to 1.0 + repeats until no text remaining + """ + remaining = len(text) + prompts = [] + weights = [] + while remaining > 0: + if ":" in text: + idx = text.index(":") # first occurrence from start + # grab up to index as sub-prompt + prompt = text[:idx] + remaining -= idx + # remove from main text + text = text[idx+1:] + # find value for weight + if " " in text: + idx = text.index(" ") # first occurence + else: # no space, read to end + idx = len(text) + if idx != 0: + try: + weight = float(text[:idx]) + except: # couldn't treat as float + print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?") + weight = 1.0 + else: # no value found + weight = 1.0 + # remove from main text + remaining -= idx + text = text[idx+1:] + # append the sub-prompt and its weight + prompts.append(prompt) + weights.append(weight) + else: # no : found + if len(text) > 0: # there is still text though + # take remainder as weight 1 + prompts.append(text) + weights.append(1.0) + remaining = 0 + return prompts, weights diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index bd87517b..ff6cbb4c 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -283,45 +283,26 @@ def thread_render(device): print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: - if runtime.thread_data.device == 'cpu' and is_alive() > 1: - # CPU is not the only device. Keep track of active time to unload resources later. - runtime.thread_data.lastActive = time.time() - # Open data generator. - res = runtime.mk_img(task.request) - if current_model_path == task.request.use_stable_diffusion_model: - current_state = ServerStates.Rendering - else: + if runtime.is_model_reload_necessary(task.request): current_state = ServerStates.LoadingModel - # Start reading from generator. 
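`split_weighted_subprompts()` is now defined in `runtime.py` itself rather than imported from `optimizedSD.optimUtils` (which the SD 2.0 code path does not provide). Going by the parsing rules in its docstring and body, a quick usage example — this assumes the function defined above is in scope:

```python
# Expected behaviour of the inlined helper, worked out from the code above.
prompts, weights = split_weighted_subprompts("a castle:1.5 at night:0.5 oil painting")
print(prompts)   # ['a castle', 'at night', 'oil painting']
print(weights)   # [1.5, 0.5, 1.0]  -- text with no ':' defaults to weight 1.0
```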
- dataQueue = None - if task.request.stream_progress_updates: - dataQueue = task.buffer_queue - for result in res: - if current_state == ServerStates.LoadingModel: - current_state = ServerStates.Rendering - current_model_path = task.request.use_stable_diffusion_model - current_vae_path = task.request.use_vae_model + runtime.reload_model() + current_model_path = task.request.use_stable_diffusion_model + current_vae_path = task.request.use_vae_model + + def step_callback(): + global current_state_error + if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): runtime.thread_data.stop_processing = True if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') - if dataQueue: - dataQueue.put(result) - if isinstance(result, str): - result = json.loads(result) - task.response = result - if 'output' in result: - for out_obj in result['output']: - if 'path' in out_obj: - img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:] - task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]] - elif 'data' in out_obj: - buf = runtime.base64_str_to_buffer(out_obj['data']) - task.temp_images[result['output'].index(out_obj)] = buf - # Before looping back to the generator, mark cache as still alive. - task_cache.keep(task.request.session_id, TASK_TTL) + + task_cache.keep(task.request.session_id, TASK_TTL) + + current_state = ServerStates.Rendering + task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback) except Exception as e: task.error = e print(traceback.format_exc()) diff --git a/ui/server.py b/ui/server.py index 61635f18..fca1aca8 100644 --- a/ui/server.py +++ b/ui/server.py @@ -116,6 +116,8 @@ def setConfig(config): bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") + config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}") + if len(config_bat) > 0: with open(config_bat_path, 'w', encoding='utf-8') as f: f.write('\r\n'.join(config_bat)) @@ -133,6 +135,8 @@ def setConfig(config): bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") + config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"") + if len(config_sh) > 1: with open(config_sh_path, 'w', encoding='utf-8') as f: f.write('\n'.join(config_sh)) @@ -140,12 +144,19 @@ def setConfig(config): print(traceback.format_exc()) def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]): + config = getConfig() + model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR] if not model_name: # When None try user configured model. - config = getConfig() + # config = getConfig() if 'model' in config and model_type in config['model']: model_name = config['model'][model_type] if model_name: + is_sd2 = config.get('test_sd2', False) + if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4 + print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. 
Using the sd-v1-4 model instead!') + model_name = 'sd-v1-4' + # Check models directory models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name) for model_extension in model_extensions: @@ -188,6 +199,7 @@ class SetAppConfigRequest(BaseModel): ui_open_browser_on_start: bool = None listen_to_network: bool = None listen_port: int = None + test_sd2: bool = None @app.post('/app_config') async def setAppConfig(req : SetAppConfigRequest): @@ -208,6 +220,8 @@ async def setAppConfig(req : SetAppConfigRequest): if 'net' not in config: config['net'] = {} config['net']['listen_port'] = int(req.listen_port) + if req.test_sd2 is not None: + config['test_sd2'] = req.test_sd2 try: setConfig(config) @@ -230,9 +244,9 @@ def is_malicious_model(file_path): return False except Exception as e: print('error while scanning', file_path, 'error:', e) - return False +known_models = {} def getModels(): models = { 'active': { @@ -255,9 +269,14 @@ def getModels(): if not file.endswith(model_extension): continue - if is_malicious_model(os.path.join(models_dir, file)): - models['scan-error'] = file - return + model_path = os.path.join(models_dir, file) + mtime = os.path.getmtime(model_path) + mod_time = known_models[model_path] if model_path in known_models else -1 + if mod_time != mtime: + if is_malicious_model(model_path): + models['scan-error'] = file + return + known_models[model_path] = mtime model_name = file[:-len(model_extension)] models['options'][model_type].append(model_name) @@ -435,6 +454,9 @@ class LogSuppressFilter(logging.Filter): return True logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) +# Check models and prepare cache for UI open +getModels() + # Start the task_manager task_manager.default_model_to_load = resolve_ckpt_to_use() task_manager.default_vae_to_load = resolve_vae_to_use()
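`getModels()` previously ran the malicious-model scan over every checkpoint file on every call; it now remembers each file's mtime in `known_models` and re-scans only files whose timestamp changed, and the new `getModels()` call at startup warms that cache before the UI first requests the model list. The same memoisation idea in isolation — `scan_file()` is a placeholder for the real `is_malicious_model()` check:

```python
# Stand-alone sketch of the mtime-keyed scan cache used by getModels().
import os

known_models: dict = {}                     # path -> mtime at the last clean scan

def scan_file(path: str) -> bool:
    """Placeholder for the real is_malicious_model() scan."""
    print("scanning", path)
    return False                            # False = looks safe

def check_model(path: str) -> bool:
    """Re-scan a model file only if it changed since the last clean scan."""
    mtime = os.path.getmtime(path)
    if known_models.get(path) == mtime:
        return False                        # unchanged -> skip the expensive scan
    if scan_file(path):
        return True                         # malicious: report it, don't cache
    known_models[path] = mtime
    return False
```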