Compare commits

...

14 Commits

Author SHA1 Message Date
50cce36d94 Initial version that works with the lstein fork. The only things not working are: CPU mode, streaming updates (live and progress bar), Turbo Mode, and keeps the model in VRAM instead of RAM 2022-09-29 20:27:46 +05:30
196649c0e9 Use the correct seed from the response 2022-09-29 13:55:09 +05:30
12182ee04f Newer images go on top 2022-09-29 13:52:48 +05:30
5db64526cc Fix a bug where batches would overwrite the previous images inside a task 2022-09-29 13:43:25 +05:30
5c2ec70eb4 Hide the sampler field when an output image is used as the new input 2022-09-29 13:12:01 +05:30
24a2c6251f Remove log statement 2022-09-29 13:08:58 +05:30
0d035d9ae9 Remove unnecessary semicolons 2022-09-29 13:08:42 +05:30
a28f1294e2 Integrate with beta; Use the outputContainer for the task; Don't show the info while a live preview is generating; Use the local task container reference instead of a seed-based identifier (will fail if the seed is same across two tasks) 2022-09-29 13:01:18 +05:30
a3b0cde59d Merge pull request #242 from Hakorr/main
Image item refactor
2022-09-29 12:03:45 +05:30
c2dec9eac4 Merge branch 'haka-fix' into main 2022-09-29 12:00:44 +05:30
10c4bee1e5 Fix for show all images 2022-09-28 00:05:34 +03:00
c1dea44fa6 Fix for live preview 2022-09-27 17:23:19 +03:00
5ba802dc68 Overlaid info 2022-09-26 17:50:27 +03:00
62048c68f0 Image item refactor and redesign 2022-09-25 02:55:11 +03:00
8 changed files with 478 additions and 729 deletions

View File

@ -15,16 +15,17 @@
@call git reset --hard
@call git pull
@call git checkout f6cfebffa752ee11a7b07497b8529d5971de916c
@call git checkout d87bd29a6862996d8a0980c1343b6f0d4eb718b4
@call git apply ..\ui\sd_internal\ddim_callback.patch
@call git apply ..\ui\sd_internal\env_yaml.patch
@REM @call git apply ..\ui\sd_internal\ddim_callback.patch
@REM @call git apply ..\ui\sd_internal\env_yaml.patch
@call git apply ..\ui\sd_internal\custom_sd.patch
@cd ..
) else (
@echo. & echo "Downloading Stable Diffusion.." & echo.
@call git clone https://github.com/basujindal/stable-diffusion.git && (
@call git clone https://github.com/invoke-ai/InvokeAI.git stable-diffusion && (
@echo sd_git_cloned >> scripts\install_status.txt
) || (
@echo "Error downloading Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
@ -33,10 +34,11 @@
)
@cd stable-diffusion
@call git checkout f6cfebffa752ee11a7b07497b8529d5971de916c
@call git checkout d87bd29a6862996d8a0980c1343b6f0d4eb718b4
@call git apply ..\ui\sd_internal\ddim_callback.patch
@call git apply ..\ui\sd_internal\env_yaml.patch
@REM @call git apply ..\ui\sd_internal\ddim_callback.patch
@REM @call git apply ..\ui\sd_internal\env_yaml.patch
@call git apply ..\ui\sd_internal\custom_sd.patch
@cd ..
)
@ -81,58 +83,6 @@
set PATH=C:\Windows\System32;%PATH%
@>nul grep -c "conda_sd_gfpgan_deps_installed" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
@echo "Packages necessary for GFPGAN (Face Correction) were already installed"
) else (
@echo. & echo "Downloading packages necessary for GFPGAN (Face Correction).." & echo.
@set PYTHONNOUSERSITE=1
@call pip install -e git+https://github.com/TencentARC/GFPGAN#egg=GFPGAN || (
@echo. & echo "Error installing the packages necessary for GFPGAN (Face Correction). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
pause
exit /b
)
@call pip install basicsr==1.4.2 || (
@echo. & echo "Error installing the basicsr package necessary for GFPGAN (Face Correction). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
pause
exit /b
)
for /f "tokens=*" %%a in ('python -c "from gfpgan import GFPGANer; print(42)"') do if "%%a" NEQ "42" (
@echo. & echo "Dependency test failed! Error installing the packages necessary for GFPGAN (Face Correction). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
pause
exit /b
)
@echo conda_sd_gfpgan_deps_installed >> ..\scripts\install_status.txt
)
@>nul grep -c "conda_sd_esrgan_deps_installed" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
@echo "Packages necessary for ESRGAN (Resolution Upscaling) were already installed"
) else (
@echo. & echo "Downloading packages necessary for ESRGAN (Resolution Upscaling).." & echo.
@set PYTHONNOUSERSITE=1
@call pip install -e git+https://github.com/xinntao/Real-ESRGAN#egg=realesrgan || (
@echo. & echo "Error installing the packages necessary for ESRGAN (Resolution Upscaling). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
pause
exit /b
)
for /f "tokens=*" %%a in ('python -c "from basicsr.archs.rrdbnet_arch import RRDBNet; from realesrgan import RealESRGANer; print(42)"') do if "%%a" NEQ "42" (
@echo. & echo "Dependency test failed! Error installing the packages necessary for ESRGAN (Resolution Upscaling). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/blob/main/Troubleshooting.md" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
pause
exit /b
)
@echo conda_sd_esrgan_deps_installed >> ..\scripts\install_status.txt
)
@>nul grep -c "conda_sd_ui_deps_installed" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
echo "Packages necessary for Stable Diffusion UI were already installed"

View File

@ -4,7 +4,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="icon" type="image/png" href="/media/favicon-16x16.png" sizes="16x16">
<link rel="icon" type="image/png" href="/media/favicon-32x32.png" sizes="32x32">
<link rel="stylesheet" href="/media/main.css?v=10">
<link rel="stylesheet" href="/media/main.css?v=21">
<link rel="stylesheet" href="/media/modifier-thumbnails.css?v=1">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css">
<link rel="stylesheet" href="/media/drawingboard.min.css">
@ -15,7 +15,7 @@
<div id="container">
<div id="top-nav">
<div id="logo">
<h1>Stable Diffusion UI <small>v2.195 <span id="updateBranchLabel"></span></small></h1>
<h1>Stable Diffusion UI <small>v2.2 <span id="updateBranchLabel"></span></small></h1>
</div>
<ul id="top-nav-items">
<li class="dropdown">
@ -213,7 +213,7 @@
</div>
</body>
<script src="media/main.js?v=14"></script>
<script src="media/main.js?v=21"></script>
<script>
async function init() {
await loadModifiers()

View File

@ -70,38 +70,36 @@ label {
font-size: 8pt;
}
.imgSeedLabel {
position: absolute;
transform: translateX(-100%);
margin-top: 5pt;
margin-left: -5pt;
font-size: 10pt;
background-color: #333;
opacity: 0.8;
color: #ddd;
border-radius: 3pt;
padding: 1pt 3pt;
}
.imgUseBtn {
position: absolute;
transform: translateX(-100%);
margin-top: 30pt;
margin-left: -5pt;
}
.imgSaveBtn {
position: absolute;
transform: translateX(-100%);
margin-top: 55pt;
margin-left: -5pt;
font-size: 0.8em;
background-color: rgb(44, 45, 48);
border-radius: 3px;
padding: 5px;
}
.imgItem {
display: inline;
padding-right: 10px;
display: inline-block;
margin-top: 1em;
margin-right: 1em;
}
.imgContainer {
display: flex;
justify-content: flex-end;
}
.imgItemInfo {
opacity: 0.5;
padding-bottom: 0.5em;
display: flex;
align-items: flex-end;
flex-direction: column;
position: absolute;
padding: 5px;
opacity: 0;
transition: 0.1s all;
}
.imgContainer:hover > .imgItemInfo {
opacity: 1;
}
.imgItemInfo * {
margin-bottom: 7px;
}
#container {
width: 90%;
margin-left: auto;
@ -409,4 +407,7 @@ img {
font-size: 10pt;
color: #aaa;
margin-bottom: 5pt;
}
.img-batch {
display: inline;
}

View File

@ -73,10 +73,10 @@ let editorTagsContainer = document.querySelector('#editor-inputs-tags-container'
let imagePreview = document.querySelector("#preview")
let previewImageField = document.querySelector('#preview-image')
previewImageField.onchange = () => changePreviewImages(previewImageField.value);
previewImageField.onchange = () => changePreviewImages(previewImageField.value)
let modifierCardSizeSlider = document.querySelector('#modifier-card-size-slider')
modifierCardSizeSlider.onchange = () => resizeModifierCards(modifierCardSizeSlider.value);
modifierCardSizeSlider.onchange = () => resizeModifierCards(modifierCardSizeSlider.value)
// let previewPrompt = document.querySelector('#preview-prompt')
@ -119,8 +119,8 @@ let bellPending = false
let taskQueue = []
let currentTask = null
const modifierThumbnailPath = 'media/modifier-thumbnails';
const activeCardClass = 'modifier-card-active';
const modifierThumbnailPath = 'media/modifier-thumbnails'
const activeCardClass = 'modifier-card-active'
function getLocalStorageItem(key, fallback) {
let item = localStorage.getItem(key)
@ -202,7 +202,7 @@ function isStreamImageProgressEnabled() {
function setStatus(statusType, msg, msgType) {
if (statusType !== 'server') {
return;
return
}
if (msgType == 'error') {
@ -259,18 +259,100 @@ async function healthCheck() {
}
}
function makeImageElement(width, height, outputContainer) {
let imgItem = document.createElement('div')
imgItem.className = 'imgItem'
function showImages(req, res, outputContainer, livePreview) {
let imageItemElements = outputContainer.querySelectorAll('.imgItem')
let img = document.createElement('img')
img.width = parseInt(width)
img.height = parseInt(height)
res.output.reverse()
imgItem.appendChild(img)
outputContainer.insertBefore(imgItem, outputContainer.firstChild)
res.output.forEach((result, index) => {
if(typeof res != 'object') return
return imgItem
const imageData = result?.data || result?.path + '?t=' + new Date().getTime(),
imageSeed = result?.seed,
imageWidth = req.width,
imageHeight = req.height;
if (!imageData.includes('/')) {
// res contained no data for the image, stop execution
setStatus('request', 'invalid image', 'error')
return
}
let imageItemElem = (index < imageItemElements.length ? imageItemElements[index] : null)
if(!imageItemElem) {
imageItemElem = document.createElement('div')
imageItemElem.className = 'imgItem'
imageItemElem.innerHTML = `
<div class="imgContainer">
<img/>
<div class="imgItemInfo">
<span class="imgSeedLabel"></span>
<button class="imgUseBtn">Use as Input</button>
<button class="imgSaveBtn">Download</button>
</div>
</div>
`
const useAsInputBtn = imageItemElem.querySelector('.imgUseBtn'),
saveImageBtn = imageItemElem.querySelector('.imgSaveBtn');
useAsInputBtn.addEventListener('click', getUseAsInputHandler(imageItemElem))
saveImageBtn.addEventListener('click', getSaveImageHandler(imageItemElem))
outputContainer.appendChild(imageItemElem)
}
const imageElem = imageItemElem.querySelector('img'),
imageSeedLabel = imageItemElem.querySelector('.imgSeedLabel');
imageElem.src = imageData
imageElem.width = parseInt(imageWidth)
imageElem.height = parseInt(imageHeight)
imageElem.setAttribute('data-seed', imageSeed)
const imageInfo = imageItemElem.querySelector('.imgItemInfo')
imageInfo.style.visibility = (livePreview ? 'hidden' : 'visible')
imageSeedLabel.innerText = 'Seed: ' + imageSeed
})
}
function getUseAsInputHandler(imageItemElem) {
return function() {
const imageElem = imageItemElem.querySelector('img')
const imgData = imageElem.src
const imageSeed = imageElem.getAttribute('data-seed')
initImageSelector.value = null
initImagePreview.src = imgData
initImagePreviewContainer.style.display = 'block'
inpaintingEditorContainer.style.display = 'none'
promptStrengthContainer.style.display = 'block'
maskSetting.checked = false
samplerSelectionContainer.style.display = 'none'
// maskSetting.style.display = 'block'
randomSeedField.checked = false
seedField.value = imageSeed
seedField.disabled = false
}
}
function getSaveImageHandler(imageItemElem) {
return function() {
const imageElem = imageItemElem.querySelector('img')
const imgData = imageElem.src
const imageSeed = imageElem.getAttribute('data-seed')
const imgDownload = document.createElement('a')
imgDownload.download = createFileName(imageSeed)
imgDownload.href = imgData
imgDownload.click()
}
}
// makes a single image. don't call this directly, use makeImage() instead
@ -281,7 +363,10 @@ async function doMakeImage(task) {
const reqBody = task.reqBody
const batchCount = task.batchCount
const outputContainer = task.outputContainer
const outputContainer = document.createElement('div')
outputContainer.className = 'img-batch'
task.outputContainer.insertBefore(outputContainer, task.outputContainer.firstChild)
const outputMsg = task['outputMsg']
const previewPrompt = task['previewPrompt']
@ -291,14 +376,6 @@ async function doMakeImage(task) {
let seed = reqBody['seed']
let numOutputs = parseInt(reqBody['num_outputs'])
let images = []
function makeImageContainers(numImages) {
for (let i = images.length; i < numImages; i++) {
images.push(makeImageElement(reqBody.width, reqBody.height, outputContainer))
}
}
try {
res = await fetch('/image', {
method: 'POST',
@ -351,14 +428,7 @@ async function doMakeImage(task) {
outputMsg.style.display = 'block'
if (stepUpdate.output !== undefined) {
makeImageContainers(numOutputs)
for (idx in stepUpdate.output) {
let imgItem = images[idx]
let img = imgItem.firstChild
let tmpImageData = stepUpdate.output[idx]
img.src = tmpImageData['path'] + '?t=' + new Date().getTime()
}
showImages(reqBody, stepUpdate, outputContainer, true)
}
}
} catch (e) {
@ -426,85 +496,11 @@ async function doMakeImage(task) {
res = undefined
}
if (!res) {
return false
}
if (!res) return false
lastPromptUsed = reqBody['prompt']
makeImageContainers(res.output.length)
for (let idx in res.output) {
let imgBody = ''
let seed = 0
try {
let imgData = res.output[idx]
imgBody = imgData.data
seed = imgData.seed
} catch (e) {
console.log(imgBody)
setStatus('request', 'invalid image', 'error')
continue
}
let imgItem = images[idx]
let img = imgItem.firstChild
img.src = imgBody
let imgItemInfo = document.createElement('span')
imgItemInfo.className = 'imgItemInfo'
imgItemInfo.style.opacity = 0
let imgSeedLabel = document.createElement('span')
imgSeedLabel.className = 'imgSeedLabel'
imgSeedLabel.innerText = 'Seed: ' + seed
let imgUseBtn = document.createElement('button')
imgUseBtn.className = 'imgUseBtn'
imgUseBtn.innerText = 'Use as Input'
let imgSaveBtn = document.createElement('button')
imgSaveBtn.className = 'imgSaveBtn'
imgSaveBtn.innerText = 'Download'
imgItem.appendChild(imgItemInfo)
imgItemInfo.appendChild(imgSeedLabel)
imgItemInfo.appendChild(imgUseBtn)
imgItemInfo.appendChild(imgSaveBtn)
imgUseBtn.addEventListener('click', function() {
initImageSelector.value = null
initImagePreview.src = imgBody
initImagePreviewContainer.style.display = 'block'
inpaintingEditorContainer.style.display = 'none'
promptStrengthContainer.style.display = 'block'
maskSetting.checked = false
// maskSetting.style.display = 'block'
randomSeedField.checked = false
seedField.value = seed
seedField.disabled = false
})
imgSaveBtn.addEventListener('click', function() {
let imgDownload = document.createElement('a')
imgDownload.download = createFileName();
imgDownload.href = imgBody
imgDownload.click()
})
imgItem.addEventListener('mouseenter', function() {
imgItemInfo.style.opacity = 1
})
imgItem.addEventListener('mouseleave', function() {
imgItemInfo.style.opacity = 0
})
}
showImages(reqBody, res, outputContainer, false)
return true
}
@ -547,7 +543,6 @@ async function checkTasks() {
task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop'
task['taskStatusLabel'].innerText = "Processing"
task['taskStatusLabel'].className += " activeTaskLabel"
console.log(task['taskStatusLabel'].className)
for (let i = 0; i < task.batchCount; i++) {
task.reqBody['seed'] = task.seed + (i * task.reqBody['num_outputs'])
@ -576,7 +571,9 @@ async function checkTasks() {
// setStatus('request', 'done', 'success')
} else {
task.outputMsg.innerText = 'Task ended after ' + time + ' seconds'
if (task.outputMsg.innerText.toLowerCase().indexOf('error') === -1) {
task.outputMsg.innerText = 'Task ended after ' + time + ' seconds'
}
}
if (randomSeedField.checked) {
@ -610,8 +607,8 @@ async function makeImage() {
let prompt = promptField.value
if (activeTags.length > 0) {
let promptTags = activeTags.map(x => x.name).join(", ");
prompt += ", " + promptTags;
let promptTags = activeTags.map(x => x.name).join(", ")
prompt += ", " + promptTags
}
let reqBody = {
@ -734,12 +731,11 @@ async function makeImage() {
// create a file name with embedded prompt and metadata
// for easier cateloging and comparison
function createFileName() {
function createFileName(seed) {
// Most important information is the prompt
let underscoreName = lastPromptUsed.replace(/[^a-zA-Z0-9]/g, '_')
underscoreName = underscoreName.substring(0, 100)
const seed = seedField.value
const steps = numInferenceStepsField.value
const guidance = guidanceScaleField.value
@ -747,20 +743,20 @@ function createFileName() {
let fileName = `${underscoreName}_Seed-${seed}_Steps-${steps}_Guidance-${guidance}`
// add the tags
// let tags = [];
// let tagString = '';
// let tags = []
// let tagString = ''
// document.querySelectorAll(modifyTagsSelector).forEach(function(tag) {
// tags.push(tag.innerHTML);
// tags.push(tag.innerHTML)
// })
// join the tags with a pipe
// if (activeTags.length > 0) {
// tagString = '_Tags-';
// tagString += tags.join('|');
// tagString = '_Tags-'
// tagString += tags.join('|')
// }
// // append empty or populated tags
// fileName += `${tagString}`;
// fileName += `${tagString}`
// add the file extension
fileName += `.png`
@ -1035,25 +1031,25 @@ maskSetting.addEventListener('click', function() {
// https://stackoverflow.com/a/8212878
function millisecondsToStr(milliseconds) {
function numberEnding (number) {
return (number > 1) ? 's' : '';
return (number > 1) ? 's' : ''
}
var temp = Math.floor(milliseconds / 1000);
var hours = Math.floor((temp %= 86400) / 3600);
var temp = Math.floor(milliseconds / 1000)
var hours = Math.floor((temp %= 86400) / 3600)
var s = ''
if (hours) {
s += hours + ' hour' + numberEnding(hours) + ' ';
s += hours + ' hour' + numberEnding(hours) + ' '
}
var minutes = Math.floor((temp %= 3600) / 60);
var minutes = Math.floor((temp %= 3600) / 60)
if (minutes) {
s += minutes + ' minute' + numberEnding(minutes) + ' ';
s += minutes + ' minute' + numberEnding(minutes) + ' '
}
var seconds = temp % 60;
var seconds = temp % 60
if (!hours && minutes < 4 && seconds) {
s += seconds + ' second' + numberEnding(seconds);
s += seconds + ' second' + numberEnding(seconds)
}
return s;
return s
}
// https://gomakethings.com/finding-the-next-and-previous-sibling-elements-that-match-a-selector-with-vanilla-js/
@ -1113,33 +1109,33 @@ function createCollapsibles(node) {
createCollapsibles()
function refreshTagsList() {
editorModifierTagsList.innerHTML = '';
editorModifierTagsList.innerHTML = ''
if (activeTags.length == 0) {
editorTagsContainer.style.display = 'none';
return;
editorTagsContainer.style.display = 'none'
return
} else {
editorTagsContainer.style.display = 'block';
editorTagsContainer.style.display = 'block'
}
activeTags.forEach((tag, index) => {
tag.element.querySelector('.modifier-card-image-overlay').innerText = '-';
tag.element.classList.add('modifier-card-tiny');
tag.element.querySelector('.modifier-card-image-overlay').innerText = '-'
tag.element.classList.add('modifier-card-tiny')
editorModifierTagsList.appendChild(tag.element);
editorModifierTagsList.appendChild(tag.element)
tag.element.addEventListener('click', () => {
let idx = activeTags.indexOf(tag);
let idx = activeTags.indexOf(tag)
if (idx !== -1) {
activeTags[idx].originElement.classList.remove(activeCardClass);
activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+';
activeTags[idx].originElement.classList.remove(activeCardClass)
activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+'
activeTags.splice(idx, 1);
refreshTagsList();
activeTags.splice(idx, 1)
refreshTagsList()
}
});
});
})
})
let brk = document.createElement('br')
brk.style.clear = 'both'
@ -1168,8 +1164,8 @@ async function getDiskPath() {
}
function createModifierCard(name, previews) {
const modifierCard = document.createElement('div');
modifierCard.className = 'modifier-card';
const modifierCard = document.createElement('div')
modifierCard.className = 'modifier-card'
modifierCard.innerHTML = `
<div class="modifier-card-overlay"></div>
<div class="modifier-card-image-container">
@ -1179,96 +1175,96 @@ function createModifierCard(name, previews) {
</div>
<div class="modifier-card-container">
<div class="modifier-card-label"><p></p></div>
</div>`;
</div>`
const image = modifierCard.querySelector('.modifier-card-image');
const errorText = modifierCard.querySelector('.modifier-card-error-label');
const label = modifierCard.querySelector('.modifier-card-label');
const image = modifierCard.querySelector('.modifier-card-image')
const errorText = modifierCard.querySelector('.modifier-card-error-label')
const label = modifierCard.querySelector('.modifier-card-label')
errorText.innerText = 'No Image';
errorText.innerText = 'No Image'
if (typeof previews == 'object') {
image.src = previews[0]; // portrait
image.setAttribute('preview-type', 'portrait');
image.setAttribute('preview-type', 'portrait')
} else {
image.remove();
image.remove()
}
const maxLabelLength = 30;
const nameWithoutBy = name.replace('by ', '');
const maxLabelLength = 30
const nameWithoutBy = name.replace('by ', '')
if(nameWithoutBy.length <= maxLabelLength) {
label.querySelector('p').innerText = nameWithoutBy;
label.querySelector('p').innerText = nameWithoutBy
} else {
const tooltipText = document.createElement('span');
tooltipText.className = 'tooltip-text';
tooltipText.innerText = name;
const tooltipText = document.createElement('span')
tooltipText.className = 'tooltip-text'
tooltipText.innerText = name
label.classList.add('tooltip');
label.appendChild(tooltipText);
label.classList.add('tooltip')
label.appendChild(tooltipText)
label.querySelector('p').innerText = nameWithoutBy.substring(0, maxLabelLength) + '...';
label.querySelector('p').innerText = nameWithoutBy.substring(0, maxLabelLength) + '...'
}
return modifierCard;
return modifierCard
}
function changePreviewImages(val) {
const previewImages = document.querySelectorAll('.modifier-card-image-container img');
const previewImages = document.querySelectorAll('.modifier-card-image-container img')
let previewArr = [];
let previewArr = []
modifiers.map(x => x.modifiers).forEach(x => previewArr.push(...x.map(m => m.previews)));
modifiers.map(x => x.modifiers).forEach(x => previewArr.push(...x.map(m => m.previews)))
previewArr = previewArr.map(x => {
let obj = {};
let obj = {}
x.forEach(preview => {
obj[preview.name] = preview.path;
});
obj[preview.name] = preview.path
})
return obj;
});
return obj
})
previewImages.forEach(previewImage => {
const currentPreviewType = previewImage.getAttribute('preview-type');
const relativePreviewPath = previewImage.src.split(modifierThumbnailPath + '/').pop();
const currentPreviewType = previewImage.getAttribute('preview-type')
const relativePreviewPath = previewImage.src.split(modifierThumbnailPath + '/').pop()
const previews = previewArr.find(preview => relativePreviewPath == preview[currentPreviewType]);
const previews = previewArr.find(preview => relativePreviewPath == preview[currentPreviewType])
if(typeof previews == 'object') {
let preview = null;
let preview = null
if (val == 'portrait') {
preview = previews.portrait;
preview = previews.portrait
}
else if (val == 'landscape') {
preview = previews.landscape;
preview = previews.landscape
}
if(preview != null) {
previewImage.src = `${modifierThumbnailPath}/${preview}`;
previewImage.setAttribute('preview-type', val);
previewImage.src = `${modifierThumbnailPath}/${preview}`
previewImage.setAttribute('preview-type', val)
}
}
});
})
}
function resizeModifierCards(val) {
const cardSizePrefix = 'modifier-card-size_';
const modifierCardClass = 'modifier-card';
const cardSizePrefix = 'modifier-card-size_'
const modifierCardClass = 'modifier-card'
const modifierCards = document.querySelectorAll(`.${modifierCardClass}`);
const cardSize = n => `${cardSizePrefix}${n}`;
const modifierCards = document.querySelectorAll(`.${modifierCardClass}`)
const cardSize = n => `${cardSizePrefix}${n}`
modifierCards.forEach(card => {
// remove existing size classes
const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix));
card.className = classes.join(' ').trim();
const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix))
card.className = classes.join(' ').trim()
if(val != 0)
card.classList.add(cardSize(val));
});
card.classList.add(cardSize(val))
})
}
async function loadModifiers() {
@ -1280,15 +1276,15 @@ async function loadModifiers() {
modifiers = res; // update global variable
res.forEach((modifierGroup, idx) => {
const title = modifierGroup.category;
const modifiers = modifierGroup.modifiers;
const title = modifierGroup.category
const modifiers = modifierGroup.modifiers
const titleEl = document.createElement('h5');
titleEl.className = 'collapsible';
titleEl.innerText = title;
const titleEl = document.createElement('h5')
titleEl.className = 'collapsible'
titleEl.innerText = title
const modifiersEl = document.createElement('div');
modifiersEl.classList.add('collapsible-content', 'editor-modifiers-leaf');
const modifiersEl = document.createElement('div')
modifiersEl.classList.add('collapsible-content', 'editor-modifiers-leaf')
if (idx == 0) {
titleEl.className += ' active'
@ -1296,21 +1292,21 @@ async function loadModifiers() {
}
modifiers.forEach(modObj => {
const modifierName = modObj.modifier;
const modifierPreviews = modObj?.previews?.map(preview => `${modifierThumbnailPath}/${preview.path}`);
const modifierName = modObj.modifier
const modifierPreviews = modObj?.previews?.map(preview => `${modifierThumbnailPath}/${preview.path}`)
const modifierCard = createModifierCard(modifierName, modifierPreviews);
const modifierCard = createModifierCard(modifierName, modifierPreviews)
if(typeof modifierCard == 'object') {
modifiersEl.appendChild(modifierCard);
modifiersEl.appendChild(modifierCard)
modifierCard.addEventListener('click', () => {
if (activeTags.map(x => x.name).includes(modifierName)) {
// remove modifier from active array
activeTags = activeTags.filter(x => x.name != modifierName);
modifierCard.classList.remove(activeCardClass);
activeTags = activeTags.filter(x => x.name != modifierName)
modifierCard.classList.remove(activeCardClass)
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+';
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+'
} else {
// add modifier to active array
activeTags.push({
@ -1318,17 +1314,17 @@ async function loadModifiers() {
'element': modifierCard.cloneNode(true),
'originElement': modifierCard,
'previews': modifierPreviews
});
})
modifierCard.classList.add(activeCardClass);
modifierCard.classList.add(activeCardClass)
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-';
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-'
}
refreshTagsList();
});
refreshTagsList()
})
}
});
})
let brk = document.createElement('br')
brk.style.clear = 'both'
@ -1346,4 +1342,4 @@ async function loadModifiers() {
} catch (e) {
console.log('error fetching modifiers', e)
}
}
}

View File

@ -23,6 +23,7 @@ class Request:
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
show_only_filtered_image: bool = False
output_format: str = "jpeg" # "png", "jpeg"
stream_progress_updates: bool = False
stream_image_progress: bool = False
@ -42,6 +43,7 @@ class Request:
"sampler": self.sampler,
"use_face_correction": self.use_face_correction,
"use_upscale": self.use_upscale,
"output_format": self.output_format,
}
def to_string(self):
@ -63,6 +65,7 @@ class Request:
use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale}
show_only_filtered_image: {self.show_only_filtered_image}
output_format: {self.output_format}
stream_progress_updates: {self.stream_progress_updates}
stream_image_progress: {self.stream_image_progress}'''

View File

@ -0,0 +1,46 @@
diff --git a/ldm/dream/conditioning.py b/ldm/dream/conditioning.py
index dfa1089..e4908ad 100644
--- a/ldm/dream/conditioning.py
+++ b/ldm/dream/conditioning.py
@@ -12,8 +12,8 @@ log_tokenization() print out colour-coded tokens and warn if trunca
import re
import torch
-def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
- uc = model.get_learned_conditioning([''])
+def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False, negative_prompt=''):
+ uc = model.get_learned_conditioning([negative_prompt])
# get weighted sub-prompts
weighted_subprompts = split_weighted_subprompts(
diff --git a/ldm/generate.py b/ldm/generate.py
index 8f67403..d88ce2d 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -205,6 +205,7 @@ class Generate:
init_mask = None,
fit = False,
strength = None,
+ init_img_is_path = True,
# these are specific to GFPGAN/ESRGAN
gfpgan_strength= 0,
save_original = False,
@@ -303,11 +304,15 @@ class Generate:
uc, c = get_uc_and_c(
prompt, model=self.model,
skip_normalize=skip_normalize,
- log_tokens=self.log_tokenization
+ log_tokens=self.log_tokenization,
+ negative_prompt=(args['negative_prompt'] if 'negative_prompt' in args else '')
)
- (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
-
+ if init_img_is_path:
+ (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
+ else:
+ (init_image,mask_image) = (init_img, init_mask)
+
if (init_image is not None) and (mask_image is not None):
generator = self._make_inpaint()
elif init_image is not None:

View File

@ -1,64 +1,47 @@
import json
import os, re
import traceback
import sys
import os
import uuid
import re
import torch
import traceback
import numpy as np
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from tqdm import tqdm, trange
from itertools import islice
from pytorch_lightning import logging
from einops import rearrange
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from einops import rearrange, repeat
from ldm.util import instantiate_from_config
from optimizedSD.optimUtils import split_weighted_subprompts
from transformers import logging
from PIL import Image, ImageOps, ImageChops
from ldm.generate import Generate
import transformers
from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
import uuid
transformers.logging.set_verbosity_error()
logging.set_verbosity_error()
# consts
config_yaml = "optimizedSD/v1-inference.yaml"
filename_regex = re.compile('[^a-zA-Z0-9]')
# api stuff
from . import Request, Response, Image as ResponseImage
import base64
import json
from io import BytesIO
#from colorama import Fore
filename_regex = re.compile('[^a-zA-Z0-9]')
generator = None
gfpgan_file = None
real_esrgan_file = None
model_gfpgan = None
model_real_esrgan = None
device = None
precision = 'autocast'
has_valid_gpu = False
force_full_precision = False
# local
stop_processing = False
temp_images = {}
ckpt_file = None
gfpgan_file = None
real_esrgan_file = None
model = None
modelCS = None
modelFS = None
model_gfpgan = None
model_real_esrgan = None
model_is_half = False
model_fs_is_half = False
device = None
unet_bs = 1
precision = 'autocast'
sampler_plms = None
sampler_ddim = None
has_valid_gpu = False
force_full_precision = False
try:
gpu = torch.cuda.current_device()
gpu_name = torch.cuda.get_device_name(gpu)
@ -79,68 +62,45 @@ except:
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
pass
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast', half_model_fs=False):
global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', precision_to_use='autocast'):
global generator
ckpt_file = ckpt_to_use
device = device_to_use if has_valid_gpu else 'cpu'
precision = precision_to_use if not force_full_precision else 'full'
unet_bs = unet_bs_to_use
if device == 'cpu':
precision = 'full'
try:
config = 'configs/models.yaml'
model = 'stable-diffusion-1.4'
sd = load_model_from_config(f"{ckpt_file}.ckpt")
li, lo = [], []
for key, value in sd.items():
sp = key.split(".")
if (sp[0]) == "model":
if "input_blocks" in sp:
li.append(key)
elif "middle_block" in sp:
li.append(key)
elif "time_embed" in sp:
li.append(key)
else:
lo.append(key)
for key in li:
sd["model1." + key[6:]] = sd.pop(key)
for key in lo:
sd["model2." + key[6:]] = sd.pop(key)
models = OmegaConf.load(config)
width = models[model].width
height = models[model].height
config = models[model].config
weights = ckpt_to_use + '.ckpt'
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
config = OmegaConf.load(f"{config_yaml}")
generator = Generate(
width=width,
height=height,
sampler_name='ddim',
weights=weights,
full_precision=(precision == 'full'),
config=config,
grid=False,
# this is solely for recreating the prompt
seamless=False,
embedding_path=None,
device_type=device,
ignore_ctrl_c=True,
)
model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.eval()
model.cdevice = device
model.unet_bs = unet_bs
model.turbo = turbo
# gets rid of annoying messages about random seed
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()
modelCS.cond_stage_model.device = device
modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
del sd
if device != "cpu" and precision == "autocast":
model.half()
modelCS.half()
model_is_half = True
else:
model_is_half = False
if half_model_fs:
modelFS.half()
model_fs_is_half = True
else:
model_fs_is_half = False
print('loaded ', ckpt_file, 'to', device, 'precision', precision)
# preload the model
generator.load_model()
def load_model_gfpgan(gfpgan_to_use):
global gfpgan_file, model_gfpgan
@ -179,7 +139,7 @@ def load_model_real_esrgan(real_esrgan_to_use):
model_real_esrgan.device = torch.device('cpu')
model_real_esrgan.model.to('cpu')
else:
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half)
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=(precision != 'full'))
model_real_esrgan.model.name = real_esrgan_to_use
@ -193,14 +153,14 @@ def mk_img(req: Request):
gc()
if device != "cpu":
modelFS.to("cpu")
modelCS.to("cpu")
# if device != "cpu":
# modelFS.to("cpu")
# modelCS.to("cpu")
model.model1.to("cpu")
model.model2.to("cpu")
# model.model1.to("cpu")
# model.model2.to("cpu")
gc()
# gc()
yield json.dumps({
"status": 'failed',
@ -208,292 +168,164 @@ def mk_img(req: Request):
})
def do_mk_img(req: Request):
global model, modelCS, modelFS, device
global model_gfpgan, model_real_esrgan
global stop_processing
stop_processing = False
res = Response()
res.request = req
res.images = []
temp_images.clear()
model.turbo = req.turbo
if req.use_cpu:
if device != 'cpu':
device = 'cpu'
if model_is_half:
del model, modelCS, modelFS
load_model_ckpt(ckpt_file, device)
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
else:
if has_valid_gpu:
prev_device = device
device = 'cuda'
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
(precision == 'full' and not req.use_full_precision and not force_full_precision) or \
(req.init_image is None and model_fs_is_half) or \
(req.init_image is not None and not model_fs_is_half and not force_full_precision):
del model, modelCS, modelFS
load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision))
if prev_device != device:
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
if req.use_face_correction != gfpgan_file:
load_model_gfpgan(req.use_face_correction)
if req.use_upscale != real_esrgan_file:
load_model_real_esrgan(req.use_upscale)
model.cdevice = device
modelCS.cond_stage_model.device = device
init_image = None
init_mask = None
opt_prompt = req.prompt
opt_seed = req.seed
opt_n_samples = req.num_outputs
opt_n_iter = 1
opt_scale = req.guidance_scale
opt_C = 4
opt_H = req.height
opt_W = req.width
opt_f = 8
opt_ddim_steps = req.num_inference_steps
opt_ddim_eta = 0.0
opt_strength = req.prompt_strength
opt_save_to_disk_path = req.save_to_disk_path
opt_init_img = req.init_image
opt_use_face_correction = req.use_face_correction
opt_use_upscale = req.use_upscale
opt_show_only_filtered = req.show_only_filtered_image
opt_format = 'png'
opt_sampler_name = req.sampler
if req.init_image is not None:
image = base64_str_to_img(req.init_image)
print(req.to_string(), '\n device', device)
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from base64")
if req.width is not None and req.height is not None:
h, w = req.height, req.width
print('\n\n Using precision:', precision)
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
init_image = generator._create_init_image(image)
seed_everything(opt_seed)
if generator._has_transparency(image) and req.mask is None: # if image has a transparent area and no mask was provided, then try to generate mask
print('>> Initial image has transparent areas. Will inpaint in these regions.')
if generator._check_for_erasure(image):
print(
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
)
init_mask = generator._create_init_mask(image) # this returns a torch tensor
batch_size = opt_n_samples
prompt = opt_prompt
assert prompt is not None
data = [batch_size * [prompt]]
if precision == "autocast" and device != "cpu":
precision_scope = autocast
else:
precision_scope = nullcontext
mask = None
if req.init_image is None:
handler = _txt2img
init_latent = None
t_enc = None
else:
handler = _img2img
init_image = load_img(req.init_image, opt_W, opt_H)
init_image = init_image.to(device)
if device != "cpu" and precision == "autocast":
if device != "cpu" and precision != "full":
init_image = init_image.half()
modelFS.to(device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None:
mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
image = base64_str_to_img(req.mask)
if device != "cpu" and precision == "autocast":
mask = mask.half()
image = ImageChops.invert(image)
move_fs_to_cpu()
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from base64")
if req.width is not None and req.height is not None:
h, w = req.height, req.width
assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(opt_strength * opt_ddim_steps)
print(f"target t_enc is {t_enc} steps")
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
if opt_save_to_disk_path is not None:
session_out_path = os.path.join(opt_save_to_disk_path, req.session_id)
init_mask = generator._create_init_mask(image)
if init_mask is not None:
req.sampler = 'plms' # hack to force the underlying implementation to initialize DDIM properly
result = generator.prompt2image(
req.prompt,
iterations = req.num_outputs,
steps = req.num_inference_steps,
seed = req.seed,
cfg_scale = req.guidance_scale,
ddim_eta = 0.0,
skip_normalize = False,
image_callback = None,
step_callback = None,
width = req.width,
height = req.height,
sampler_name = req.sampler,
seamless = False,
log_tokenization= False,
with_variations = None,
variation_amount = 0.0,
# these are specific to img2img and inpaint
init_img = init_image,
init_mask = init_mask,
fit = False,
strength = req.prompt_strength,
init_img_is_path = False,
# these are specific to GFPGAN/ESRGAN
gfpgan_strength= 0,
save_original = False,
upscale = None,
negative_prompt= req.negative_prompt,
)
has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
(req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
print('has filter', has_filters)
return_orig_img = not has_filters or not req.show_only_filtered_image
res = Response()
res.request = req
res.images = []
if req.save_to_disk_path is not None:
session_out_path = os.path.join(req.save_to_disk_path, req.session_id)
os.makedirs(session_out_path, exist_ok=True)
else:
session_out_path = None
seeds = ""
with torch.no_grad():
for n in trange(opt_n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
for img, seed in result:
if req.save_to_disk_path is not None:
prompt_flattened = filename_regex.sub('_', req.prompt)
prompt_flattened = prompt_flattened[:50]
with precision_scope("cuda"):
modelCS.to(device)
uc = None
if opt_scale != 1.0:
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
img_id = str(uuid.uuid4())[-8:]
subprompts, weights = split_weighted_subprompts(prompts[0])
if len(subprompts) > 1:
c = torch.zeros_like(uc)
totalWeight = sum(weights)
# normalize each "sub prompt" and add it
for i in range(len(subprompts)):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = modelCS.get_learned_conditioning(prompts)
file_path = f"{prompt_flattened}_{img_id}"
img_out_path = os.path.join(session_out_path, f"{file_path}.{req.output_format}")
meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")
modelFS.to(device)
if return_orig_img:
save_image(img, img_out_path)
partial_x_samples = None
def img_callback(x_samples, i):
nonlocal partial_x_samples
save_metadata(meta_out_path, req.prompt, seed, req.width, req.height, req.num_inference_steps, req.guidance_scale, req.prompt_strength, req.use_face_correction, req.use_upscale, req.sampler, req.negative_prompt)
partial_x_samples = x_samples
if return_orig_img:
img_data = img_to_base64_str(img)
res_image_orig = ResponseImage(data=img_data, seed=seed)
res.images.append(res_image_orig)
if req.stream_progress_updates:
n_steps = opt_ddim_steps if req.init_image is None else t_enc
progress = {"step": i, "total_steps": n_steps}
if req.save_to_disk_path is not None:
res_image_orig.path_abs = img_out_path
if req.stream_image_progress and i % 5 == 0:
partial_images = []
if has_filters and not stop_processing:
print('Applying filters..')
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
gc()
filters_applied = []
del img, x_sample, x_samples_ddim
# don't delete x_samples, it is used in the code that called this callback
np_img = img.convert('RGB')
np_img = np.array(np_img, dtype=np.uint8)
temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
if req.use_face_correction:
_, _, np_img = model_gfpgan.enhance(np_img, has_aligned=False, only_center_face=False, paste_back=True)
filters_applied.append(req.use_face_correction)
progress['output'] = partial_images
if req.use_upscale:
np_img, _ = model_real_esrgan.enhance(np_img)
filters_applied.append(req.use_upscale)
yield json.dumps(progress)
filtered_image = Image.fromarray(np_img)
if stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
filtered_img_data = img_to_base64_str(filtered_image)
res_image_filtered = ResponseImage(data=filtered_img_data, seed=seed)
res.images.append(res_image_filtered)
# run the handler
try:
if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask)
filters_applied = "_".join(filters_applied)
yield from x_samples
if req.save_to_disk_path is not None:
filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{req.output_format}")
save_image(filtered_image, filtered_img_out_path)
res_image_filtered.path_abs = filtered_img_out_path
x_samples = partial_x_samples
except UserInitiatedStop:
if partial_x_samples is None:
continue
x_samples = partial_x_samples
print("saving images")
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
has_filters = (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN'))
return_orig_img = not has_filters or not opt_show_only_filtered
if stop_processing:
return_orig_img = True
if opt_save_to_disk_path is not None:
prompt_flattened = filename_regex.sub('_', prompts[0])
prompt_flattened = prompt_flattened[:50]
img_id = str(uuid.uuid4())[-8:]
file_path = f"{prompt_flattened}_{img_id}"
img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}")
meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")
if return_orig_img:
save_image(img, img_out_path)
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt)
if return_orig_img:
img_data = img_to_base64_str(img)
res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
res.images.append(res_image_orig)
if opt_save_to_disk_path is not None:
res_image_orig.path_abs = img_out_path
del img
if has_filters and not stop_processing:
print('Applying filters..')
gc()
filters_applied = []
if opt_use_face_correction:
_, _, output = model_gfpgan.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
x_sample = output[:,:,::-1]
filters_applied.append(opt_use_face_correction)
if opt_use_upscale:
output, _ = model_real_esrgan.enhance(x_sample[:,:,::-1])
x_sample = output[:,:,::-1]
filters_applied.append(opt_use_upscale)
filtered_image = Image.fromarray(x_sample)
filtered_img_data = img_to_base64_str(filtered_image)
res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed)
res.images.append(res_image_filtered)
filters_applied = "_".join(filters_applied)
if opt_save_to_disk_path is not None:
filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}")
save_image(filtered_image, filtered_img_out_path)
res_image_filtered.path_abs = filtered_img_out_path
del filtered_image
seeds += str(opt_seed) + ","
opt_seed += 1
move_fs_to_cpu()
gc()
del x_samples, x_samples_ddim, x_sample
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
del filtered_image
del img
print('Task completed')
@ -505,8 +337,8 @@ def save_image(img, img_out_path):
except:
print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt):
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}"
def save_metadata(meta_out_path, prompt, seed, width, height, num_inference_steps, guidance_scale, prompt_strength, use_correct_face, use_upscale, sampler_name, negative_prompt):
metadata = f"{prompt}\nWidth: {width}\nHeight: {height}\nSeed: {seed}\nSteps: {num_inference_steps}\nGuidance Scale: {guidance_scale}\nPrompt Strength: {prompt_strength}\nUse Face Correction: {use_correct_face}\nUse Upscaling: {use_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}"
try:
with open(meta_out_path, 'w') as f:
@ -514,68 +346,6 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except:
print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelCS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
if sampler_name == 'ddim':
model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
yield from samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
# encode (scaled latent)
z_enc = model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(device),
opt_seed,
opt_ddim_eta,
opt_ddim_steps,
)
x_T = None if mask is None else init_latent
# decode it
samples_ddim = model.sample(
t_enc,
c,
z_enc,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
img_callback=img_callback,
mask=mask,
x_T=x_T,
sampler = 'ddim'
)
yield from samples_ddim
def move_fs_to_cpu():
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
def gc():
if device == 'cpu':
return
@ -583,25 +353,6 @@ def gc():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
# internal
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
return sd
# utils
class UserInitiatedStop(Exception):
pass
def load_img(img_str, w0, h0):
image = base64_str_to_img(img_str).convert("RGB")
w, h = image.size

View File

@ -58,6 +58,7 @@ class ImageRequest(BaseModel):
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
show_only_filtered_image: bool = False
output_format: str = "jpeg" # "png", "jpeg"
stream_progress_updates: bool = False
stream_image_progress: bool = False
@ -123,6 +124,7 @@ def image(req : ImageRequest):
r.use_upscale: str = req.use_upscale
r.use_face_correction = req.use_face_correction
r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format
r.stream_progress_updates = True # the underlying implementation only supports streaming
r.stream_image_progress = req.stream_image_progress