Compare commits

...

14 Commits

Author SHA1 Message Date
50cce36d94 Initial version that works with the lstein fork. The only things not working yet are: CPU mode, streaming updates (live preview and progress bar), and Turbo Mode; the model is also kept in VRAM instead of RAM 2022-09-29 20:27:46 +05:30
196649c0e9 Use the correct seed from the response 2022-09-29 13:55:09 +05:30
12182ee04f Newer images go on top 2022-09-29 13:52:48 +05:30
5db64526cc Fix a bug where batches would overwrite the previous images inside a task 2022-09-29 13:43:25 +05:30
5c2ec70eb4 Hide the sampler field when an output image is used as the new input 2022-09-29 13:12:01 +05:30
24a2c6251f Remove log statement 2022-09-29 13:08:58 +05:30
0d035d9ae9 Remove unnecessary semicolons 2022-09-29 13:08:42 +05:30
a28f1294e2 Integrate with beta; Use the outputContainer for the task; Don't show the info while a live preview is generating; Use the local task container reference instead of a seed-based identifier (will fail if the seed is same across two tasks) 2022-09-29 13:01:18 +05:30
a3b0cde59d Merge pull request #242 from Hakorr/main (Image item refactor) 2022-09-29 12:03:45 +05:30
c2dec9eac4 Merge branch 'haka-fix' into main 2022-09-29 12:00:44 +05:30
10c4bee1e5 Fix for show all images 2022-09-28 00:05:34 +03:00
c1dea44fa6 Fix for live preview 2022-09-27 17:23:19 +03:00
5ba802dc68 Overlaid info 2022-09-26 17:50:27 +03:00
62048c68f0 Image item refactor and redesign 2022-09-25 02:55:11 +03:00
8 changed files with 478 additions and 729 deletions

View File

@@ -15,16 +15,17 @@
     @call git reset --hard
     @call git pull
-    @call git checkout f6cfebffa752ee11a7b07497b8529d5971de916c
+    @call git checkout d87bd29a6862996d8a0980c1343b6f0d4eb718b4
-    @call git apply ..\ui\sd_internal\ddim_callback.patch
-    @call git apply ..\ui\sd_internal\env_yaml.patch
+    @REM @call git apply ..\ui\sd_internal\ddim_callback.patch
+    @REM @call git apply ..\ui\sd_internal\env_yaml.patch
+    @call git apply ..\ui\sd_internal\custom_sd.patch
     @cd ..
 ) else (
     @echo. & echo "Downloading Stable Diffusion.." & echo.
-    @call git clone https://github.com/basujindal/stable-diffusion.git && (
+    @call git clone https://github.com/invoke-ai/InvokeAI.git stable-diffusion && (
         @echo sd_git_cloned >> scripts\install_status.txt
     ) || (
         @echo "Error downloading Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
@@ -33,10 +34,11 @@
     )
     @cd stable-diffusion
-    @call git checkout f6cfebffa752ee11a7b07497b8529d5971de916c
+    @call git checkout d87bd29a6862996d8a0980c1343b6f0d4eb718b4
-    @call git apply ..\ui\sd_internal\ddim_callback.patch
-    @call git apply ..\ui\sd_internal\env_yaml.patch
+    @REM @call git apply ..\ui\sd_internal\ddim_callback.patch
+    @REM @call git apply ..\ui\sd_internal\env_yaml.patch
+    @call git apply ..\ui\sd_internal\custom_sd.patch
     @cd ..
 )
@@ -81,58 +83,6 @@
 set PATH=C:\Windows\System32;%PATH%
-@>nul grep -c "conda_sd_gfpgan_deps_installed" ..\scripts\install_status.txt
-@if "%ERRORLEVEL%" EQU "0" (
-    @echo "Packages necessary for GFPGAN (Face Correction) were already installed"
-) else (
-    @echo. & echo "Downloading packages necessary for GFPGAN (Face Correction).." & echo.
-    @set PYTHONNOUSERSITE=1
-    @call pip install -e git+https://github.com/TencentARC/GFPGAN#egg=GFPGAN || (
-        @echo. & echo "Error installing the packages necessary for GFPGAN (Face Correction). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
-        pause
-        exit /b
-    )
-    @call pip install basicsr==1.4.2 || (
-        @echo. & echo "Error installing the basicsr package necessary for GFPGAN (Face Correction). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
-        pause
-        exit /b
-    )
-    for /f "tokens=*" %%a in ('python -c "from gfpgan import GFPGANer; print(42)"') do if "%%a" NEQ "42" (
-        @echo. & echo "Dependency test failed! Error installing the packages necessary for GFPGAN (Face Correction). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
-        pause
-        exit /b
-    )
-    @echo conda_sd_gfpgan_deps_installed >> ..\scripts\install_status.txt
-)
-@>nul grep -c "conda_sd_esrgan_deps_installed" ..\scripts\install_status.txt
-@if "%ERRORLEVEL%" EQU "0" (
-    @echo "Packages necessary for ESRGAN (Resolution Upscaling) were already installed"
-) else (
-    @echo. & echo "Downloading packages necessary for ESRGAN (Resolution Upscaling).." & echo.
-    @set PYTHONNOUSERSITE=1
-    @call pip install -e git+https://github.com/xinntao/Real-ESRGAN#egg=realesrgan || (
-        @echo. & echo "Error installing the packages necessary for ESRGAN (Resolution Upscaling). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
-        pause
-        exit /b
-    )
-    for /f "tokens=*" %%a in ('python -c "from basicsr.archs.rrdbnet_arch import RRDBNet; from realesrgan import RealESRGANer; print(42)"') do if "%%a" NEQ "42" (
-        @echo. & echo "Dependency test failed! Error installing the packages necessary for ESRGAN (Resolution Upscaling). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
-        pause
-        exit /b
-    )
-    @echo conda_sd_esrgan_deps_installed >> ..\scripts\install_status.txt
-)
 @>nul grep -c "conda_sd_ui_deps_installed" ..\scripts\install_status.txt
 @if "%ERRORLEVEL%" EQU "0" (
     echo "Packages necessary for Stable Diffusion UI were already installed"

View File

@@ -4,7 +4,7 @@
     <meta name="viewport" content="width=device-width, initial-scale=1.0">
     <link rel="icon" type="image/png" href="/media/favicon-16x16.png" sizes="16x16">
     <link rel="icon" type="image/png" href="/media/favicon-32x32.png" sizes="32x32">
-    <link rel="stylesheet" href="/media/main.css?v=10">
+    <link rel="stylesheet" href="/media/main.css?v=21">
     <link rel="stylesheet" href="/media/modifier-thumbnails.css?v=1">
     <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css">
     <link rel="stylesheet" href="/media/drawingboard.min.css">
@@ -15,7 +15,7 @@
 <div id="container">
     <div id="top-nav">
         <div id="logo">
-            <h1>Stable Diffusion UI <small>v2.195 <span id="updateBranchLabel"></span></small></h1>
+            <h1>Stable Diffusion UI <small>v2.2 <span id="updateBranchLabel"></span></small></h1>
         </div>
         <ul id="top-nav-items">
             <li class="dropdown">
@@ -213,7 +213,7 @@
     </div>
 </body>
-<script src="media/main.js?v=14"></script>
+<script src="media/main.js?v=21"></script>
 <script>
 async function init() {
     await loadModifiers()

View File

@@ -70,38 +70,36 @@ label {
     font-size: 8pt;
 }
 .imgSeedLabel {
-    position: absolute;
-    transform: translateX(-100%);
-    margin-top: 5pt;
-    margin-left: -5pt;
-    font-size: 10pt;
-    background-color: #333;
-    opacity: 0.8;
-    color: #ddd;
-    border-radius: 3pt;
-    padding: 1pt 3pt;
-}
-.imgUseBtn {
-    position: absolute;
-    transform: translateX(-100%);
-    margin-top: 30pt;
-    margin-left: -5pt;
-}
-.imgSaveBtn {
-    position: absolute;
-    transform: translateX(-100%);
-    margin-top: 55pt;
-    margin-left: -5pt;
+    font-size: 0.8em;
+    background-color: rgb(44, 45, 48);
+    border-radius: 3px;
+    padding: 5px;
 }
 .imgItem {
-    display: inline;
-    padding-right: 10px;
+    display: inline-block;
+    margin-top: 1em;
+    margin-right: 1em;
+}
+.imgContainer {
+    display: flex;
+    justify-content: flex-end;
 }
 .imgItemInfo {
-    opacity: 0.5;
+    padding-bottom: 0.5em;
+    display: flex;
+    align-items: flex-end;
+    flex-direction: column;
+    position: absolute;
+    padding: 5px;
+    opacity: 0;
+    transition: 0.1s all;
+}
+.imgContainer:hover > .imgItemInfo {
+    opacity: 1;
+}
+.imgItemInfo * {
+    margin-bottom: 7px;
 }
 #container {
     width: 90%;
     margin-left: auto;
@@ -409,4 +407,7 @@ img {
     font-size: 10pt;
     color: #aaa;
     margin-bottom: 5pt;
+}
+.img-batch {
+    display: inline;
 }

View File

@@ -73,10 +73,10 @@ let editorTagsContainer = document.querySelector('#editor-inputs-tags-container'
 let imagePreview = document.querySelector("#preview")
 let previewImageField = document.querySelector('#preview-image')
-previewImageField.onchange = () => changePreviewImages(previewImageField.value);
+previewImageField.onchange = () => changePreviewImages(previewImageField.value)
 let modifierCardSizeSlider = document.querySelector('#modifier-card-size-slider')
-modifierCardSizeSlider.onchange = () => resizeModifierCards(modifierCardSizeSlider.value);
+modifierCardSizeSlider.onchange = () => resizeModifierCards(modifierCardSizeSlider.value)
 // let previewPrompt = document.querySelector('#preview-prompt')
@@ -119,8 +119,8 @@ let bellPending = false
 let taskQueue = []
 let currentTask = null
-const modifierThumbnailPath = 'media/modifier-thumbnails';
-const activeCardClass = 'modifier-card-active';
+const modifierThumbnailPath = 'media/modifier-thumbnails'
+const activeCardClass = 'modifier-card-active'
 function getLocalStorageItem(key, fallback) {
     let item = localStorage.getItem(key)
@@ -202,7 +202,7 @@ function isStreamImageProgressEnabled() {
 function setStatus(statusType, msg, msgType) {
     if (statusType !== 'server') {
-        return;
+        return
     }
     if (msgType == 'error') {
@@ -259,18 +259,100 @@ async function healthCheck() {
     }
 }
-function makeImageElement(width, height, outputContainer) {
-    let imgItem = document.createElement('div')
-    imgItem.className = 'imgItem'
-    let img = document.createElement('img')
-    img.width = parseInt(width)
-    img.height = parseInt(height)
-    imgItem.appendChild(img)
-    outputContainer.insertBefore(imgItem, outputContainer.firstChild)
-    return imgItem
+function showImages(req, res, outputContainer, livePreview) {
+    let imageItemElements = outputContainer.querySelectorAll('.imgItem')
+    res.output.reverse()
+    res.output.forEach((result, index) => {
+        if(typeof res != 'object') return
+        const imageData = result?.data || result?.path + '?t=' + new Date().getTime(),
+              imageSeed = result?.seed,
+              imageWidth = req.width,
+              imageHeight = req.height;
+        if (!imageData.includes('/')) {
+            // res contained no data for the image, stop execution
+            setStatus('request', 'invalid image', 'error')
+            return
+        }
+        let imageItemElem = (index < imageItemElements.length ? imageItemElements[index] : null)
+        if(!imageItemElem) {
+            imageItemElem = document.createElement('div')
+            imageItemElem.className = 'imgItem'
+            imageItemElem.innerHTML = `
+                <div class="imgContainer">
+                    <img/>
+                    <div class="imgItemInfo">
+                        <span class="imgSeedLabel"></span>
+                        <button class="imgUseBtn">Use as Input</button>
+                        <button class="imgSaveBtn">Download</button>
+                    </div>
+                </div>
+            `
+            const useAsInputBtn = imageItemElem.querySelector('.imgUseBtn'),
+                  saveImageBtn = imageItemElem.querySelector('.imgSaveBtn');
+            useAsInputBtn.addEventListener('click', getUseAsInputHandler(imageItemElem))
+            saveImageBtn.addEventListener('click', getSaveImageHandler(imageItemElem))
+            outputContainer.appendChild(imageItemElem)
+        }
+        const imageElem = imageItemElem.querySelector('img'),
+              imageSeedLabel = imageItemElem.querySelector('.imgSeedLabel');
+        imageElem.src = imageData
+        imageElem.width = parseInt(imageWidth)
+        imageElem.height = parseInt(imageHeight)
+        imageElem.setAttribute('data-seed', imageSeed)
+        const imageInfo = imageItemElem.querySelector('.imgItemInfo')
+        imageInfo.style.visibility = (livePreview ? 'hidden' : 'visible')
+        imageSeedLabel.innerText = 'Seed: ' + imageSeed
+    })
+}
+function getUseAsInputHandler(imageItemElem) {
+    return function() {
+        const imageElem = imageItemElem.querySelector('img')
+        const imgData = imageElem.src
+        const imageSeed = imageElem.getAttribute('data-seed')
+        initImageSelector.value = null
+        initImagePreview.src = imgData
+        initImagePreviewContainer.style.display = 'block'
+        inpaintingEditorContainer.style.display = 'none'
+        promptStrengthContainer.style.display = 'block'
+        maskSetting.checked = false
+        samplerSelectionContainer.style.display = 'none'
+        // maskSetting.style.display = 'block'
+        randomSeedField.checked = false
+        seedField.value = imageSeed
+        seedField.disabled = false
+    }
+}
+function getSaveImageHandler(imageItemElem) {
+    return function() {
+        const imageElem = imageItemElem.querySelector('img')
+        const imgData = imageElem.src
+        const imageSeed = imageElem.getAttribute('data-seed')
+        const imgDownload = document.createElement('a')
+        imgDownload.download = createFileName(imageSeed)
+        imgDownload.href = imgData
+        imgDownload.click()
+    }
 }
 // makes a single image. don't call this directly, use makeImage() instead
@@ -281,7 +363,10 @@ async function doMakeImage(task) {
     const reqBody = task.reqBody
     const batchCount = task.batchCount
-    const outputContainer = task.outputContainer
+    const outputContainer = document.createElement('div')
+    outputContainer.className = 'img-batch'
+    task.outputContainer.insertBefore(outputContainer, task.outputContainer.firstChild)
     const outputMsg = task['outputMsg']
     const previewPrompt = task['previewPrompt']
@@ -291,14 +376,6 @@ async function doMakeImage(task) {
     let seed = reqBody['seed']
     let numOutputs = parseInt(reqBody['num_outputs'])
-    let images = []
-    function makeImageContainers(numImages) {
-        for (let i = images.length; i < numImages; i++) {
-            images.push(makeImageElement(reqBody.width, reqBody.height, outputContainer))
-        }
-    }
     try {
         res = await fetch('/image', {
             method: 'POST',
@@ -351,14 +428,7 @@ async function doMakeImage(task) {
                     outputMsg.style.display = 'block'
                     if (stepUpdate.output !== undefined) {
-                        makeImageContainers(numOutputs)
-                        for (idx in stepUpdate.output) {
-                            let imgItem = images[idx]
-                            let img = imgItem.firstChild
-                            let tmpImageData = stepUpdate.output[idx]
-                            img.src = tmpImageData['path'] + '?t=' + new Date().getTime()
-                        }
+                        showImages(reqBody, stepUpdate, outputContainer, true)
                     }
                 }
             } catch (e) {
@@ -426,85 +496,11 @@ async function doMakeImage(task) {
         res = undefined
     }
-    if (!res) {
-        return false
-    }
+    if (!res) return false
     lastPromptUsed = reqBody['prompt']
-    makeImageContainers(res.output.length)
-    for (let idx in res.output) {
-        let imgBody = ''
-        let seed = 0
-        try {
-            let imgData = res.output[idx]
-            imgBody = imgData.data
-            seed = imgData.seed
-        } catch (e) {
-            console.log(imgBody)
-            setStatus('request', 'invalid image', 'error')
-            continue
-        }
-        let imgItem = images[idx]
-        let img = imgItem.firstChild
-        img.src = imgBody
-        let imgItemInfo = document.createElement('span')
-        imgItemInfo.className = 'imgItemInfo'
-        imgItemInfo.style.opacity = 0
-        let imgSeedLabel = document.createElement('span')
-        imgSeedLabel.className = 'imgSeedLabel'
-        imgSeedLabel.innerText = 'Seed: ' + seed
-        let imgUseBtn = document.createElement('button')
-        imgUseBtn.className = 'imgUseBtn'
-        imgUseBtn.innerText = 'Use as Input'
-        let imgSaveBtn = document.createElement('button')
-        imgSaveBtn.className = 'imgSaveBtn'
-        imgSaveBtn.innerText = 'Download'
-        imgItem.appendChild(imgItemInfo)
-        imgItemInfo.appendChild(imgSeedLabel)
-        imgItemInfo.appendChild(imgUseBtn)
-        imgItemInfo.appendChild(imgSaveBtn)
-        imgUseBtn.addEventListener('click', function() {
-            initImageSelector.value = null
-            initImagePreview.src = imgBody
-            initImagePreviewContainer.style.display = 'block'
-            inpaintingEditorContainer.style.display = 'none'
-            promptStrengthContainer.style.display = 'block'
-            maskSetting.checked = false
-            // maskSetting.style.display = 'block'
-            randomSeedField.checked = false
-            seedField.value = seed
-            seedField.disabled = false
-        })
-        imgSaveBtn.addEventListener('click', function() {
-            let imgDownload = document.createElement('a')
-            imgDownload.download = createFileName();
-            imgDownload.href = imgBody
-            imgDownload.click()
-        })
-        imgItem.addEventListener('mouseenter', function() {
-            imgItemInfo.style.opacity = 1
-        })
-        imgItem.addEventListener('mouseleave', function() {
-            imgItemInfo.style.opacity = 0
-        })
-    }
+    showImages(reqBody, res, outputContainer, false)
     return true
 }
@@ -547,7 +543,6 @@ async function checkTasks() {
     task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop'
     task['taskStatusLabel'].innerText = "Processing"
     task['taskStatusLabel'].className += " activeTaskLabel"
-    console.log(task['taskStatusLabel'].className)
     for (let i = 0; i < task.batchCount; i++) {
         task.reqBody['seed'] = task.seed + (i * task.reqBody['num_outputs'])
@@ -576,7 +571,9 @@ async function checkTasks() {
         // setStatus('request', 'done', 'success')
     } else {
-        task.outputMsg.innerText = 'Task ended after ' + time + ' seconds'
+        if (task.outputMsg.innerText.toLowerCase().indexOf('error') === -1) {
+            task.outputMsg.innerText = 'Task ended after ' + time + ' seconds'
+        }
     }
     if (randomSeedField.checked) {
@@ -610,8 +607,8 @@ async function makeImage() {
     let prompt = promptField.value
     if (activeTags.length > 0) {
-        let promptTags = activeTags.map(x => x.name).join(", ");
-        prompt += ", " + promptTags;
+        let promptTags = activeTags.map(x => x.name).join(", ")
+        prompt += ", " + promptTags
     }
     let reqBody = {
@@ -734,12 +731,11 @@ async function makeImage() {
 // create a file name with embedded prompt and metadata
 // for easier cateloging and comparison
-function createFileName() {
+function createFileName(seed) {
     // Most important information is the prompt
     let underscoreName = lastPromptUsed.replace(/[^a-zA-Z0-9]/g, '_')
     underscoreName = underscoreName.substring(0, 100)
-    const seed = seedField.value
     const steps = numInferenceStepsField.value
     const guidance = guidanceScaleField.value
@@ -747,20 +743,20 @@ function createFileName() {
     let fileName = `${underscoreName}_Seed-${seed}_Steps-${steps}_Guidance-${guidance}`
     // add the tags
-    // let tags = [];
-    // let tagString = '';
+    // let tags = []
+    // let tagString = ''
     // document.querySelectorAll(modifyTagsSelector).forEach(function(tag) {
-    //     tags.push(tag.innerHTML);
+    //     tags.push(tag.innerHTML)
     // })
     // join the tags with a pipe
     // if (activeTags.length > 0) {
-    //     tagString = '_Tags-';
-    //     tagString += tags.join('|');
+    //     tagString = '_Tags-'
+    //     tagString += tags.join('|')
     // }
     // // append empty or populated tags
-    // fileName += `${tagString}`;
+    // fileName += `${tagString}`
     // add the file extension
     fileName += `.png`
@@ -1035,25 +1031,25 @@ maskSetting.addEventListener('click', function() {
 // https://stackoverflow.com/a/8212878
 function millisecondsToStr(milliseconds) {
     function numberEnding (number) {
-        return (number > 1) ? 's' : '';
+        return (number > 1) ? 's' : ''
     }
-    var temp = Math.floor(milliseconds / 1000);
-    var hours = Math.floor((temp %= 86400) / 3600);
+    var temp = Math.floor(milliseconds / 1000)
+    var hours = Math.floor((temp %= 86400) / 3600)
     var s = ''
     if (hours) {
-        s += hours + ' hour' + numberEnding(hours) + ' ';
+        s += hours + ' hour' + numberEnding(hours) + ' '
     }
-    var minutes = Math.floor((temp %= 3600) / 60);
+    var minutes = Math.floor((temp %= 3600) / 60)
     if (minutes) {
-        s += minutes + ' minute' + numberEnding(minutes) + ' ';
+        s += minutes + ' minute' + numberEnding(minutes) + ' '
     }
-    var seconds = temp % 60;
+    var seconds = temp % 60
     if (!hours && minutes < 4 && seconds) {
-        s += seconds + ' second' + numberEnding(seconds);
+        s += seconds + ' second' + numberEnding(seconds)
     }
-    return s;
+    return s
 }
 // https://gomakethings.com/finding-the-next-and-previous-sibling-elements-that-match-a-selector-with-vanilla-js/
@@ -1113,33 +1109,33 @@ function createCollapsibles(node) {
 createCollapsibles()
 function refreshTagsList() {
-    editorModifierTagsList.innerHTML = '';
+    editorModifierTagsList.innerHTML = ''
     if (activeTags.length == 0) {
-        editorTagsContainer.style.display = 'none';
-        return;
+        editorTagsContainer.style.display = 'none'
+        return
     } else {
-        editorTagsContainer.style.display = 'block';
+        editorTagsContainer.style.display = 'block'
     }
     activeTags.forEach((tag, index) => {
-        tag.element.querySelector('.modifier-card-image-overlay').innerText = '-';
-        tag.element.classList.add('modifier-card-tiny');
-        editorModifierTagsList.appendChild(tag.element);
+        tag.element.querySelector('.modifier-card-image-overlay').innerText = '-'
+        tag.element.classList.add('modifier-card-tiny')
+        editorModifierTagsList.appendChild(tag.element)
         tag.element.addEventListener('click', () => {
-            let idx = activeTags.indexOf(tag);
+            let idx = activeTags.indexOf(tag)
             if (idx !== -1) {
-                activeTags[idx].originElement.classList.remove(activeCardClass);
-                activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+';
-                activeTags.splice(idx, 1);
-                refreshTagsList();
+                activeTags[idx].originElement.classList.remove(activeCardClass)
+                activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+'
+                activeTags.splice(idx, 1)
+                refreshTagsList()
             }
-        });
-    });
+        })
+    })
     let brk = document.createElement('br')
     brk.style.clear = 'both'
@@ -1168,8 +1164,8 @@ async function getDiskPath() {
 }
 function createModifierCard(name, previews) {
-    const modifierCard = document.createElement('div');
-    modifierCard.className = 'modifier-card';
+    const modifierCard = document.createElement('div')
+    modifierCard.className = 'modifier-card'
     modifierCard.innerHTML = `
         <div class="modifier-card-overlay"></div>
         <div class="modifier-card-image-container">
@@ -1179,96 +1175,96 @@ function createModifierCard(name, previews) {
         </div>
         <div class="modifier-card-container">
             <div class="modifier-card-label"><p></p></div>
-        </div>`;
+        </div>`
-    const image = modifierCard.querySelector('.modifier-card-image');
-    const errorText = modifierCard.querySelector('.modifier-card-error-label');
-    const label = modifierCard.querySelector('.modifier-card-label');
-    errorText.innerText = 'No Image';
+    const image = modifierCard.querySelector('.modifier-card-image')
+    const errorText = modifierCard.querySelector('.modifier-card-error-label')
+    const label = modifierCard.querySelector('.modifier-card-label')
+    errorText.innerText = 'No Image'
     if (typeof previews == 'object') {
         image.src = previews[0]; // portrait
-        image.setAttribute('preview-type', 'portrait');
+        image.setAttribute('preview-type', 'portrait')
     } else {
-        image.remove();
+        image.remove()
     }
-    const maxLabelLength = 30;
-    const nameWithoutBy = name.replace('by ', '');
+    const maxLabelLength = 30
+    const nameWithoutBy = name.replace('by ', '')
     if(nameWithoutBy.length <= maxLabelLength) {
-        label.querySelector('p').innerText = nameWithoutBy;
+        label.querySelector('p').innerText = nameWithoutBy
     } else {
-        const tooltipText = document.createElement('span');
-        tooltipText.className = 'tooltip-text';
-        tooltipText.innerText = name;
-        label.classList.add('tooltip');
-        label.appendChild(tooltipText);
-        label.querySelector('p').innerText = nameWithoutBy.substring(0, maxLabelLength) + '...';
+        const tooltipText = document.createElement('span')
+        tooltipText.className = 'tooltip-text'
+        tooltipText.innerText = name
+        label.classList.add('tooltip')
+        label.appendChild(tooltipText)
+        label.querySelector('p').innerText = nameWithoutBy.substring(0, maxLabelLength) + '...'
    }
-    return modifierCard;
+    return modifierCard
 }
 function changePreviewImages(val) {
-    const previewImages = document.querySelectorAll('.modifier-card-image-container img');
-    let previewArr = [];
-    modifiers.map(x => x.modifiers).forEach(x => previewArr.push(...x.map(m => m.previews)));
+    const previewImages = document.querySelectorAll('.modifier-card-image-container img')
+    let previewArr = []
+    modifiers.map(x => x.modifiers).forEach(x => previewArr.push(...x.map(m => m.previews)))
     previewArr = previewArr.map(x => {
-        let obj = {};
+        let obj = {}
         x.forEach(preview => {
-            obj[preview.name] = preview.path;
-        });
-        return obj;
-    });
+            obj[preview.name] = preview.path
+        })
+        return obj
+    })
     previewImages.forEach(previewImage => {
-        const currentPreviewType = previewImage.getAttribute('preview-type');
-        const relativePreviewPath = previewImage.src.split(modifierThumbnailPath + '/').pop();
-        const previews = previewArr.find(preview => relativePreviewPath == preview[currentPreviewType]);
+        const currentPreviewType = previewImage.getAttribute('preview-type')
+        const relativePreviewPath = previewImage.src.split(modifierThumbnailPath + '/').pop()
+        const previews = previewArr.find(preview => relativePreviewPath == preview[currentPreviewType])
         if(typeof previews == 'object') {
-            let preview = null;
+            let preview = null
            if (val == 'portrait') {
-                preview = previews.portrait;
+                preview = previews.portrait
             }
             else if (val == 'landscape') {
-                preview = previews.landscape;
+                preview = previews.landscape
             }
             if(preview != null) {
-                previewImage.src = `${modifierThumbnailPath}/${preview}`;
-                previewImage.setAttribute('preview-type', val);
+                previewImage.src = `${modifierThumbnailPath}/${preview}`
+                previewImage.setAttribute('preview-type', val)
             }
         }
-    });
+    })
 }
 function resizeModifierCards(val) {
-    const cardSizePrefix = 'modifier-card-size_';
-    const modifierCardClass = 'modifier-card';
-    const modifierCards = document.querySelectorAll(`.${modifierCardClass}`);
-    const cardSize = n => `${cardSizePrefix}${n}`;
+    const cardSizePrefix = 'modifier-card-size_'
+    const modifierCardClass = 'modifier-card'
+    const modifierCards = document.querySelectorAll(`.${modifierCardClass}`)
+    const cardSize = n => `${cardSizePrefix}${n}`
     modifierCards.forEach(card => {
         // remove existing size classes
-        const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix));
-        card.className = classes.join(' ').trim();
+        const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix))
+        card.className = classes.join(' ').trim()
         if(val != 0)
-            card.classList.add(cardSize(val));
-    });
+            card.classList.add(cardSize(val))
+    })
 }
 async function loadModifiers() {
@@ -1280,15 +1276,15 @@ async function loadModifiers() {
         modifiers = res; // update global variable
         res.forEach((modifierGroup, idx) => {
-            const title = modifierGroup.category;
-            const modifiers = modifierGroup.modifiers;
-            const titleEl = document.createElement('h5');
-            titleEl.className = 'collapsible';
-            titleEl.innerText = title;
-            const modifiersEl = document.createElement('div');
-            modifiersEl.classList.add('collapsible-content', 'editor-modifiers-leaf');
+            const title = modifierGroup.category
+            const modifiers = modifierGroup.modifiers
+            const titleEl = document.createElement('h5')
+            titleEl.className = 'collapsible'
+            titleEl.innerText = title
+            const modifiersEl = document.createElement('div')
+            modifiersEl.classList.add('collapsible-content', 'editor-modifiers-leaf')
             if (idx == 0) {
                 titleEl.className += ' active'
@@ -1296,21 +1292,21 @@ async function loadModifiers() {
             }
             modifiers.forEach(modObj => {
-                const modifierName = modObj.modifier;
-                const modifierPreviews = modObj?.previews?.map(preview => `${modifierThumbnailPath}/${preview.path}`);
-                const modifierCard = createModifierCard(modifierName, modifierPreviews);
+                const modifierName = modObj.modifier
+                const modifierPreviews = modObj?.previews?.map(preview => `${modifierThumbnailPath}/${preview.path}`)
+                const modifierCard = createModifierCard(modifierName, modifierPreviews)
                 if(typeof modifierCard == 'object') {
-                    modifiersEl.appendChild(modifierCard);
+                    modifiersEl.appendChild(modifierCard)
                     modifierCard.addEventListener('click', () => {
                         if (activeTags.map(x => x.name).includes(modifierName)) {
                             // remove modifier from active array
-                            activeTags = activeTags.filter(x => x.name != modifierName);
-                            modifierCard.classList.remove(activeCardClass);
-                            modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+';
+                            activeTags = activeTags.filter(x => x.name != modifierName)
+                            modifierCard.classList.remove(activeCardClass)
+                            modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+'
                         } else {
                             // add modifier to active array
                             activeTags.push({
@@ -1318,17 +1314,17 @@ async function loadModifiers() {
                                 'element': modifierCard.cloneNode(true),
                                 'originElement': modifierCard,
                                 'previews': modifierPreviews
-                            });
-                            modifierCard.classList.add(activeCardClass);
-                            modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-';
+                            })
+                            modifierCard.classList.add(activeCardClass)
+                            modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-'
                         }
-                        refreshTagsList();
-                    });
+                        refreshTagsList()
+                    })
                 }
-            });
+            })
             let brk = document.createElement('br')
             brk.style.clear = 'both'
@@ -1346,4 +1342,4 @@ async function loadModifiers() {
     } catch (e) {
         console.log('error fetching modifiers', e)
     }
-}
+}

View File

@@ -23,6 +23,7 @@ class Request:
     use_face_correction: str = None # or "GFPGANv1.3"
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
     show_only_filtered_image: bool = False
+    output_format: str = "jpeg" # "png", "jpeg"
     stream_progress_updates: bool = False
     stream_image_progress: bool = False
@@ -42,6 +43,7 @@ class Request:
             "sampler": self.sampler,
             "use_face_correction": self.use_face_correction,
             "use_upscale": self.use_upscale,
+            "output_format": self.output_format,
         }
     def to_string(self):
@@ -63,6 +65,7 @@ class Request:
             use_face_correction: {self.use_face_correction}
             use_upscale: {self.use_upscale}
            show_only_filtered_image: {self.show_only_filtered_image}
+            output_format: {self.output_format}
             stream_progress_updates: {self.stream_progress_updates}
             stream_image_progress: {self.stream_image_progress}'''
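
The `output_format` field added above defaults to "jpeg" and is threaded through both `to_json()` and `to_string()`, so logs and saved metadata record the format the client asked for. A minimal sketch of that round-trip, using a stripped-down stand-in for the real `Request` class (only the fields from this hunk; the dataclass wrapper and example values are illustrative, not the project's actual implementation):

    from dataclasses import dataclass, asdict

    @dataclass
    class Request:  # stand-in: the real class carries many more fields
        use_face_correction: str = None       # or "GFPGANv1.3"
        use_upscale: str = None               # or "RealESRGAN_x4plus"
        show_only_filtered_image: bool = False
        output_format: str = "jpeg"           # "png", "jpeg"

        def to_json(self):
            return asdict(self)

    req = Request(output_format="png")
    assert req.to_json()["output_format"] == "png"  # omitted, it stays "jpeg"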

View File

@@ -0,0 +1,46 @@
diff --git a/ldm/dream/conditioning.py b/ldm/dream/conditioning.py
index dfa1089..e4908ad 100644
--- a/ldm/dream/conditioning.py
+++ b/ldm/dream/conditioning.py
@@ -12,8 +12,8 @@ log_tokenization() print out colour-coded tokens and warn if trunca
 import re
 import torch
-def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
-    uc = model.get_learned_conditioning([''])
+def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False, negative_prompt=''):
+    uc = model.get_learned_conditioning([negative_prompt])
     # get weighted sub-prompts
     weighted_subprompts = split_weighted_subprompts(
diff --git a/ldm/generate.py b/ldm/generate.py
index 8f67403..d88ce2d 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -205,6 +205,7 @@ class Generate:
             init_mask = None,
             fit = False,
             strength = None,
+            init_img_is_path = True,
             # these are specific to GFPGAN/ESRGAN
             gfpgan_strength= 0,
             save_original = False,
@@ -303,11 +304,15 @@ class Generate:
             uc, c = get_uc_and_c(
                 prompt, model=self.model,
                 skip_normalize=skip_normalize,
-                log_tokens=self.log_tokenization
+                log_tokens=self.log_tokenization,
+                negative_prompt=(args['negative_prompt'] if 'negative_prompt' in args else '')
             )
-            (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
-
+            if init_img_is_path:
+                (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
+            else:
+                (init_image,mask_image) = (init_img, init_mask)
+
             if (init_image is not None) and (mask_image is not None):
                 generator = self._make_inpaint()
             elif init_image is not None:
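
The first chunk of this patch is what makes negative prompts work: the unconditional embedding used for classifier-free guidance is computed from `negative_prompt` instead of the empty string, so sampling is steered away from that text. A conceptual sketch of where that embedding lands during each denoising step (the `unet` callable and tensors are stand-ins, not the fork's actual API):

    import torch

    def cfg_noise_prediction(unet, x, t, cond, uncond, guidance_scale):
        # With the patch above, `uncond` embeds the negative prompt rather
        # than '', so guidance pushes away from it at every step.
        e_uncond = unet(x, t, uncond)  # prediction for the negative prompt
        e_cond = unet(x, t, cond)      # prediction for the positive prompt
        return e_uncond + guidance_scale * (e_cond - e_uncond)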

View File

@@ -1,64 +1,47 @@
-import json
-import os, re
-import traceback
+import sys
+import os
+import uuid
+import re
 import torch
+import traceback
 import numpy as np
 from omegaconf import OmegaConf
-from PIL import Image, ImageOps
-from tqdm import tqdm, trange
-from itertools import islice
+from pytorch_lightning import logging
 from einops import rearrange
-import time
-from pytorch_lightning import seed_everything
-from torch import autocast
-from contextlib import nullcontext
-from einops import rearrange, repeat
-from ldm.util import instantiate_from_config
-from optimizedSD.optimUtils import split_weighted_subprompts
-from transformers import logging
+from PIL import Image, ImageOps, ImageChops
+from ldm.generate import Generate
+import transformers
 from gfpgan import GFPGANer
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from realesrgan import RealESRGANer
-import uuid
-logging.set_verbosity_error()
-# consts
-config_yaml = "optimizedSD/v1-inference.yaml"
-filename_regex = re.compile('[^a-zA-Z0-9]')
+transformers.logging.set_verbosity_error()
+logging.set_verbosity_error()
 # api stuff
 from . import Request, Response, Image as ResponseImage
 import base64
+import json
 from io import BytesIO
-#from colorama import Fore
+filename_regex = re.compile('[^a-zA-Z0-9]')
+generator = None
+gfpgan_file = None
+real_esrgan_file = None
+model_gfpgan = None
+model_real_esrgan = None
+device = None
+precision = 'autocast'
+has_valid_gpu = False
+force_full_precision = False
 # local
 stop_processing = False
 temp_images = {}
-ckpt_file = None
-gfpgan_file = None
-real_esrgan_file = None
-model = None
-modelCS = None
-modelFS = None
-model_gfpgan = None
-model_real_esrgan = None
-model_is_half = False
-model_fs_is_half = False
-device = None
-unet_bs = 1
-precision = 'autocast'
-sampler_plms = None
-sampler_ddim = None
-has_valid_gpu = False
-force_full_precision = False
 try:
     gpu = torch.cuda.current_device()
     gpu_name = torch.cuda.get_device_name(gpu)
@@ -79,68 +62,45 @@ except:
     print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
     pass
-def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast', half_model_fs=False):
-    global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
-    ckpt_file = ckpt_to_use
+def load_model_ckpt(ckpt_to_use, device_to_use='cuda', precision_to_use='autocast'):
+    global generator
     device = device_to_use if has_valid_gpu else 'cpu'
     precision = precision_to_use if not force_full_precision else 'full'
-    unet_bs = unet_bs_to_use
-    if device == 'cpu':
-        precision = 'full'
-    sd = load_model_from_config(f"{ckpt_file}.ckpt")
-    li, lo = [], []
-    for key, value in sd.items():
-        sp = key.split(".")
-        if (sp[0]) == "model":
-            if "input_blocks" in sp:
-                li.append(key)
-            elif "middle_block" in sp:
-                li.append(key)
-            elif "time_embed" in sp:
-                li.append(key)
-            else:
-                lo.append(key)
-    for key in li:
-        sd["model1." + key[6:]] = sd.pop(key)
-    for key in lo:
-        sd["model2." + key[6:]] = sd.pop(key)
-    config = OmegaConf.load(f"{config_yaml}")
-    model = instantiate_from_config(config.modelUNet)
-    _, _ = model.load_state_dict(sd, strict=False)
-    model.eval()
-    model.cdevice = device
-    model.unet_bs = unet_bs
-    model.turbo = turbo
-    modelCS = instantiate_from_config(config.modelCondStage)
-    _, _ = modelCS.load_state_dict(sd, strict=False)
-    modelCS.eval()
-    modelCS.cond_stage_model.device = device
-    modelFS = instantiate_from_config(config.modelFirstStage)
-    _, _ = modelFS.load_state_dict(sd, strict=False)
-    modelFS.eval()
-    del sd
-    if device != "cpu" and precision == "autocast":
-        model.half()
-        modelCS.half()
-        model_is_half = True
-    else:
-        model_is_half = False
-    if half_model_fs:
-        modelFS.half()
-        model_fs_is_half = True
-    else:
-        model_fs_is_half = False
-    print('loaded ', ckpt_file, 'to', device, 'precision', precision)
+    try:
+        config = 'configs/models.yaml'
+        model = 'stable-diffusion-1.4'
+        models = OmegaConf.load(config)
+        width = models[model].width
+        height = models[model].height
+        config = models[model].config
+        weights = ckpt_to_use + '.ckpt'
+    except (FileNotFoundError, IOError, KeyError) as e:
+        print(f'{e}. Aborting.')
+        sys.exit(-1)
+    generator = Generate(
+        width=width,
+        height=height,
+        sampler_name='ddim',
+        weights=weights,
+        full_precision=(precision == 'full'),
+        config=config,
+        grid=False,
+        # this is solely for recreating the prompt
+        seamless=False,
+        embedding_path=None,
+        device_type=device,
+        ignore_ctrl_c=True,
+    )
+    # gets rid of annoying messages about random seed
+    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
+    # preload the model
+    generator.load_model()
 def load_model_gfpgan(gfpgan_to_use):
     global gfpgan_file, model_gfpgan
@@ -179,7 +139,7 @@ def load_model_real_esrgan(real_esrgan_to_use):
         model_real_esrgan.device = torch.device('cpu')
         model_real_esrgan.model.to('cpu')
     else:
-        model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half)
+        model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=(precision != 'full'))
     model_real_esrgan.model.name = real_esrgan_to_use
@@ -193,14 +153,14 @@ def mk_img(req: Request):
         gc()
-        if device != "cpu":
-            modelFS.to("cpu")
-            modelCS.to("cpu")
-            model.model1.to("cpu")
-            model.model2.to("cpu")
-            gc()
+        # if device != "cpu":
+        #     modelFS.to("cpu")
+        #     modelCS.to("cpu")
+        #     model.model1.to("cpu")
+        #     model.model2.to("cpu")
+        #     gc()
         yield json.dumps({
             "status": 'failed',
@@ -208,292 +168,164 @@
         })
 def do_mk_img(req: Request):
-    global model, modelCS, modelFS, device
-    global model_gfpgan, model_real_esrgan
-    global stop_processing
     stop_processing = False
-    res = Response()
-    res.request = req
-    res.images = []
-    temp_images.clear()
-    model.turbo = req.turbo
-    if req.use_cpu:
-        if device != 'cpu':
-            device = 'cpu'
-            if model_is_half:
-                del model, modelCS, modelFS
-                load_model_ckpt(ckpt_file, device)
-            load_model_gfpgan(gfpgan_file)
-            load_model_real_esrgan(real_esrgan_file)
-    else:
-        if has_valid_gpu:
-            prev_device = device
-            device = 'cuda'
-            if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
-                (precision == 'full' and not req.use_full_precision and not force_full_precision) or \
-                (req.init_image is None and model_fs_is_half) or \
-                (req.init_image is not None and not model_fs_is_half and not force_full_precision):
-                del model, modelCS, modelFS
-                load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision))
-            if prev_device != device:
-                load_model_gfpgan(gfpgan_file)
-                load_model_real_esrgan(real_esrgan_file)
     if req.use_face_correction != gfpgan_file:
         load_model_gfpgan(req.use_face_correction)
     if req.use_upscale != real_esrgan_file:
         load_model_real_esrgan(req.use_upscale)
-    model.cdevice = device
-    modelCS.cond_stage_model.device = device
-    opt_prompt = req.prompt
-    opt_seed = req.seed
-    opt_n_samples = req.num_outputs
-    opt_n_iter = 1
-    opt_scale = req.guidance_scale
-    opt_C = 4
-    opt_H = req.height
-    opt_W = req.width
-    opt_f = 8
-    opt_ddim_steps = req.num_inference_steps
-    opt_ddim_eta = 0.0
-    opt_strength = req.prompt_strength
-    opt_save_to_disk_path = req.save_to_disk_path
-    opt_init_img = req.init_image
-    opt_use_face_correction = req.use_face_correction
-    opt_use_upscale = req.use_upscale
-    opt_show_only_filtered = req.show_only_filtered_image
-    opt_format = 'png'
-    opt_sampler_name = req.sampler
-    print(req.to_string(), '\n device', device)
-    print('\n\n Using precision:', precision)
-    seed_everything(opt_seed)
-    batch_size = opt_n_samples
-    prompt = opt_prompt
-    assert prompt is not None
-    data = [batch_size * [prompt]]
-    if precision == "autocast" and device != "cpu":
-        precision_scope = autocast
-    else:
-        precision_scope = nullcontext
-    mask = None
-    if req.init_image is None:
-        handler = _txt2img
-        init_latent = None
-        t_enc = None
-    else:
-        handler = _img2img
-        init_image = load_img(req.init_image, opt_W, opt_H)
-        init_image = init_image.to(device)
-        if device != "cpu" and precision == "autocast":
+    init_image = None
+    init_mask = None
+    if req.init_image is not None:
+        image = base64_str_to_img(req.init_image)
+        w, h = image.size
+        print(f"loaded input image of size ({w}, {h}) from base64")
+        if req.width is not None and req.height is not None:
+            h, w = req.height, req.width
+        w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
+        image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
+        init_image = generator._create_init_image(image)
+        if generator._has_transparency(image) and req.mask is None: # if image has a transparent area and no mask was provided, then try to generate mask
+            print('>> Initial image has transparent areas. Will inpaint in these regions.')
+            if generator._check_for_erasure(image):
+                print(
+                    '>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
+                    '>> Inpainting will be suboptimal. Please preserve the colors when making\n',
+                    '>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
+                )
+            init_mask = generator._create_init_mask(image) # this returns a torch tensor
+        if device != "cpu" and precision != "full":
             init_image = init_image.half()
-        modelFS.to(device)
-        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
-        init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
     if req.mask is not None:
-        mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
-        mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
-        mask = repeat(mask, '1 ... -> b ...', b=batch_size)
-        if device != "cpu" and precision == "autocast":
-            mask = mask.half()
-    move_fs_to_cpu()
-    assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
-    t_enc = int(opt_strength * opt_ddim_steps)
-    print(f"target t_enc is {t_enc} steps")
-    if opt_save_to_disk_path is not None:
-        session_out_path = os.path.join(opt_save_to_disk_path, req.session_id)
+        image = base64_str_to_img(req.mask)
+        image = ImageChops.invert(image)
+        w, h = image.size
+        print(f"loaded input image of size ({w}, {h}) from base64")
+        if req.width is not None and req.height is not None:
+            h, w = req.height, req.width
+        w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
+        image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
+        init_mask = generator._create_init_mask(image)
+    if init_mask is not None:
+        req.sampler = 'plms' # hack to force the underlying implementation to initialize DDIM properly
+    result = generator.prompt2image(
+        req.prompt,
+        iterations = req.num_outputs,
+        steps = req.num_inference_steps,
+        seed = req.seed,
+        cfg_scale = req.guidance_scale,
+        ddim_eta = 0.0,
+        skip_normalize = False,
+        image_callback = None,
+        step_callback = None,
+        width = req.width,
+        height = req.height,
+        sampler_name = req.sampler,
+        seamless = False,
+        log_tokenization= False,
+        with_variations = None,
+        variation_amount = 0.0,
+        # these are specific to img2img and inpaint
+        init_img = init_image,
+        init_mask = init_mask,
+        fit = False,
+        strength = req.prompt_strength,
+        init_img_is_path = False,
+        # these are specific to GFPGAN/ESRGAN
+        gfpgan_strength= 0,
+        save_original = False,
+        upscale = None,
+        negative_prompt= req.negative_prompt,
+    )
+    has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
+                  (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
+    print('has filter', has_filters)
+    return_orig_img = not has_filters or not req.show_only_filtered_image
+    res = Response()
+    res.request = req
+    res.images = []
+    if req.save_to_disk_path is not None:
+        session_out_path = os.path.join(req.save_to_disk_path, req.session_id)
         os.makedirs(session_out_path, exist_ok=True)
     else:
         session_out_path = None
-    seeds = ""
-    with torch.no_grad():
-        for n in trange(opt_n_iter, desc="Sampling"):
-            for prompts in tqdm(data, desc="data"):
-                with precision_scope("cuda"):
-                    modelCS.to(device)
-                    uc = None
-                    if opt_scale != 1.0:
-                        uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
-                    if isinstance(prompts, tuple):
-                        prompts = list(prompts)
-                    subprompts, weights = split_weighted_subprompts(prompts[0])
-                    if len(subprompts) > 1:
-                        c = torch.zeros_like(uc)
-                        totalWeight = sum(weights)
-                        # normalize each "sub prompt" and add it
-                        for i in range(len(subprompts)):
-                            weight = weights[i]
-                            # if not skip_normalize:
-                            weight = weight / totalWeight
-                            c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
-                    else:
-                        c = modelCS.get_learned_conditioning(prompts)
-                    modelFS.to(device)
-                    partial_x_samples = None
-                    def img_callback(x_samples, i):
-                        nonlocal partial_x_samples
-                        partial_x_samples = x_samples
-                        if req.stream_progress_updates:
-                            n_steps = opt_ddim_steps if req.init_image is None else t_enc
-                            progress = {"step": i, "total_steps": n_steps}
-                            if req.stream_image_progress and i % 5 == 0:
-                                partial_images = []
-                                for i in range(batch_size):
-                                    x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
-                                    x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                                    x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
-                                    x_sample = x_sample.astype(np.uint8)
-                                    img = Image.fromarray(x_sample)
-                                    buf = BytesIO()
-                                    img.save(buf, format='JPEG')
-                                    buf.seek(0)
-                                    del img, x_sample, x_samples_ddim
-                                    # don't delete x_samples, it is used in the code that called this callback
-                                    temp_images[str(req.session_id) + '/' + str(i)] = buf
-                                    partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
-                                progress['output'] = partial_images
-                            yield json.dumps(progress)
-                        if stop_processing:
-                            raise UserInitiatedStop("User requested that we stop processing")
-                    # run the handler
-                    try:
-                        if handler == _txt2img:
-                            x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name)
-                        else:
-                            x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask)
-                        yield from x_samples
-                        x_samples = partial_x_samples
-                    except UserInitiatedStop:
-                        if partial_x_samples is None:
-                            continue
-                        x_samples = partial_x_samples
-                    print("saving images")
-                    for i in range(batch_size):
-                        x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
-                        x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
-                        x_sample = x_sample.astype(np.uint8)
-                        img = Image.fromarray(x_sample)
-                        has_filters = (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
-                            (opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN'))
-                        return_orig_img = not has_filters or not opt_show_only_filtered
-                        if stop_processing:
-                            return_orig_img = True
-                        if opt_save_to_disk_path is not None:
-                            prompt_flattened = filename_regex.sub('_', prompts[0])
-                            prompt_flattened = prompt_flattened[:50]
-                            img_id = str(uuid.uuid4())[-8:]
-                            file_path = f"{prompt_flattened}_{img_id}"
-                            img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}")
-                            meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")
-                            if return_orig_img:
-                                save_image(img, img_out_path)
-                            save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt)
-                        if return_orig_img:
-                            img_data = img_to_base64_str(img)
-                            res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
-                            res.images.append(res_image_orig)
-                            if opt_save_to_disk_path is not None:
-                                res_image_orig.path_abs = img_out_path
-                        del img
-                        if has_filters and not stop_processing:
-                            print('Applying filters..')
-                            gc()
-                            filters_applied = []
-                            if opt_use_face_correction:
-                                _, _, output = model_gfpgan.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
-                                x_sample = output[:,:,::-1]
-                                filters_applied.append(opt_use_face_correction)
-                            if opt_use_upscale:
-                                output, _ = model_real_esrgan.enhance(x_sample[:,:,::-1])
-                                x_sample = output[:,:,::-1]
-                                filters_applied.append(opt_use_upscale)
-                            filtered_image = Image.fromarray(x_sample)
-                            filtered_img_data = img_to_base64_str(filtered_image)
-                            res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed)
-                            res.images.append(res_image_filtered)
-                            filters_applied = "_".join(filters_applied)
-                            if opt_save_to_disk_path is not None:
-                                filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}")
-                                save_image(filtered_image, filtered_img_out_path)
-                                res_image_filtered.path_abs = filtered_img_out_path
-                            del filtered_image
-                        seeds += str(opt_seed) + ","
-                        opt_seed += 1
-                    move_fs_to_cpu()
-                    gc()
-                    del x_samples, x_samples_ddim, x_sample
-                    print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
+    for img, seed in result:
+        if req.save_to_disk_path is not None:
+            prompt_flattened = filename_regex.sub('_', req.prompt)
+            prompt_flattened = prompt_flattened[:50]
+            img_id = str(uuid.uuid4())[-8:]
+            file_path = f"{prompt_flattened}_{img_id}"
+            img_out_path = os.path.join(session_out_path, f"{file_path}.{req.output_format}")
+            meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")
+            if return_orig_img:
+                save_image(img, img_out_path)
+            save_metadata(meta_out_path, req.prompt, seed, req.width, req.height, req.num_inference_steps, req.guidance_scale, req.prompt_strength, req.use_face_correction, req.use_upscale, req.sampler, req.negative_prompt)
+        if return_orig_img:
+            img_data = img_to_base64_str(img)
+            res_image_orig = ResponseImage(data=img_data, seed=seed)
+            res.images.append(res_image_orig)
+            if req.save_to_disk_path is not None:
+                res_image_orig.path_abs = img_out_path
+        if has_filters and not stop_processing:
+            print('Applying filters..')
+            gc()
+            filters_applied = []
+            np_img = img.convert('RGB')
+            np_img = np.array(np_img, dtype=np.uint8)
+            if req.use_face_correction:
+                _, _, np_img = model_gfpgan.enhance(np_img, has_aligned=False, only_center_face=False, paste_back=True)
+                filters_applied.append(req.use_face_correction)
+            if req.use_upscale:
+                np_img, _ = model_real_esrgan.enhance(np_img)
+                filters_applied.append(req.use_upscale)
+            filtered_image = Image.fromarray(np_img)
+            filtered_img_data = img_to_base64_str(filtered_image)
+            res_image_filtered = ResponseImage(data=filtered_img_data, seed=seed)
+            res.images.append(res_image_filtered)
+            filters_applied = "_".join(filters_applied)
+            if req.save_to_disk_path is not None:
+                filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{req.output_format}")
+                save_image(filtered_image, filtered_img_out_path)
+                res_image_filtered.path_abs = filtered_img_out_path
+            del filtered_image
+        del img
     print('Task completed')
@@ -505,8 +337,8 @@ def save_image(img, img_out_path):
     except:
         print('could not save the file', traceback.format_exc())
-def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt):
-    metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}"
+def save_metadata(meta_out_path, prompt, seed, width, height, num_inference_steps, guidance_scale, prompt_strength, use_correct_face, use_upscale, sampler_name, negative_prompt):
+    metadata = f"{prompt}\nWidth: {width}\nHeight: {height}\nSeed: {seed}\nSteps: {num_inference_steps}\nGuidance Scale: {guidance_scale}\nPrompt Strength: {prompt_strength}\nUse Face Correction: {use_correct_face}\nUse Upscaling: {use_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}"
     try:
         with open(meta_out_path, 'w') as f:
@@ -514,68 +346,6 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
     except:
         print('could not save the file', traceback.format_exc())
-def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
-    shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
-    if device != "cpu":
-        mem = torch.cuda.memory_allocated() / 1e6
-        modelCS.to("cpu")
-        while torch.cuda.memory_allocated() / 1e6 >= mem:
-            time.sleep(1)
-    if sampler_name == 'ddim':
-        model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
-    samples_ddim = model.sample(
-        S=opt_ddim_steps,
-        conditioning=c,
-        seed=opt_seed,
-        shape=shape,
-        verbose=False,
-        unconditional_guidance_scale=opt_scale,
-        unconditional_conditioning=uc,
-        eta=opt_ddim_eta,
-        x_T=start_code,
-        img_callback=img_callback,
-        mask=mask,
-        sampler = sampler_name,
-    )
-    yield from samples_ddim
-def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
-    # encode (scaled latent)
-    z_enc = model.stochastic_encode(
-        init_latent,
-        torch.tensor([t_enc] * batch_size).to(device),
-        opt_seed,
-        opt_ddim_eta,
-        opt_ddim_steps,
-    )
-    x_T = None if mask is None else init_latent
-    # decode it
-    samples_ddim = model.sample(
-        t_enc,
-        c,
-        z_enc,
-        unconditional_guidance_scale=opt_scale,
-        unconditional_conditioning=uc,
-        img_callback=img_callback,
-        mask=mask,
-        x_T=x_T,
-        sampler = 'ddim'
-    )
-    yield from samples_ddim
-def move_fs_to_cpu():
-    if device != "cpu":
-        mem = torch.cuda.memory_allocated() / 1e6
-        modelFS.to("cpu")
-        while torch.cuda.memory_allocated() / 1e6 >= mem:
-            time.sleep(1)
 def gc():
     if device == 'cpu':
         return
@@ -583,25 +353,6 @@ def gc():
     torch.cuda.empty_cache()
     torch.cuda.ipc_collect()
-# internal
-def chunk(it, size):
-    it = iter(it)
-    return iter(lambda: tuple(islice(it, size)), ())
-def load_model_from_config(ckpt, verbose=False):
-    print(f"Loading model from {ckpt}")
-    pl_sd = torch.load(ckpt, map_location="cpu")
-    if "global_step" in pl_sd:
-        print(f"Global Step: {pl_sd['global_step']}")
-    sd = pl_sd["state_dict"]
-    return sd
-# utils
-class UserInitiatedStop(Exception):
-    pass
 def load_img(img_str, w0, h0):
     image = base64_str_to_img(img_str).convert("RGB")
     w, h = image.size

View File

@@ -58,6 +58,7 @@ class ImageRequest(BaseModel):
    use_face_correction: str = None # or "GFPGANv1.3"
    use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
    show_only_filtered_image: bool = False
+    output_format: str = "jpeg" # "png", "jpeg"
    stream_progress_updates: bool = False
    stream_image_progress: bool = False
@@ -123,6 +124,7 @@ def image(req : ImageRequest):
    r.use_upscale: str = req.use_upscale
    r.use_face_correction = req.use_face_correction
    r.show_only_filtered_image = req.show_only_filtered_image
+    r.output_format = req.output_format
    r.stream_progress_updates = True # the underlying implementation only supports streaming
    r.stream_image_progress = req.stream_image_progress
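
End to end, the new field travels from this HTTP model through `Request` down to the `img_out_path` f-strings in runtime.py. A hedged example of a client exercising it (the UI posts to `/image`, as the main.js diff shows; the host, port, and the other field values here are illustrative only):

    import json
    import urllib.request

    body = {
        "prompt": "a photograph of an astronaut riding a horse",
        "num_outputs": 1,
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
        "width": 512,
        "height": 512,
        "output_format": "png",  # new field; the server defaults to "jpeg"
    }
    req = urllib.request.Request(
        "http://localhost:9000/image",  # illustrative host/port
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    # the response streams JSON progress updates (stream_progress_updates is
    # forced to True above), so read it incrementally
    with urllib.request.urlopen(req) as resp:
        print(resp.read(512))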