Merge pull request #292 from cmdr2/beta

Custom Models; Reduce RAM usage; JPG/PNG option
This commit is contained in:
cmdr2 2022-10-07 19:45:17 +05:30 committed by GitHub
commit 23d20c918f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 421 additions and 266 deletions

View File

@ -165,6 +165,9 @@ call WHERE uvicorn > .tmp
if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
@if exist "sd-v1-4.ckpt" (
for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (
echo "Data files (weights) necessary for Stable Diffusion were already downloaded. Using the HuggingFace 4 GB Model."

View File

@ -161,6 +161,9 @@ fi
mkdir -p "../models/stable-diffusion"
echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
if [ -f "sd-v1-4.ckpt" ]; then
model_size=`find "sd-v1-4.ckpt" -printf "%s"`

View File

@ -4,7 +4,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="icon" type="image/png" href="/media/favicon-16x16.png" sizes="16x16">
<link rel="icon" type="image/png" href="/media/favicon-32x32.png" sizes="32x32">
<link rel="stylesheet" href="/media/main.css?v=10">
<link rel="stylesheet" href="/media/main.css?v=21">
<link rel="stylesheet" href="/media/modifier-thumbnails.css?v=1">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css">
<link rel="stylesheet" href="/media/drawingboard.min.css">
@ -15,7 +15,7 @@
<div id="container">
<div id="top-nav">
<div id="logo">
<h1>Stable Diffusion UI <small>v2.195 <span id="updateBranchLabel"></span></small></h1>
<h1>Stable Diffusion UI <small>v2.2 <span id="updateBranchLabel"></span></small></h1>
</div>
<ul id="top-nav-items">
<li class="dropdown">
@ -89,6 +89,11 @@
<li><b class="settings-subheader">Image Settings</b></li>
<li class="pl-5"><label for="seed">Seed:</label> <input id="seed" name="seed" size="10" value="30000"> <input id="random_seed" name="random_seed" type="checkbox" checked> <label for="random_seed">Random Image</label></li>
<li class="pl-5"><label for="num_outputs_total">Number of images to make:</label> <input id="num_outputs_total" name="num_outputs_total" value="1" size="1"> <label for="num_outputs_parallel">Generate in parallel:</label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="1"> (images at once)</li>
<li class="pl-5"><label for="stable_diffusion_model">Model:</label>
<select id="stable_diffusion_model" name="stable_diffusion_model">
<!-- <option value="sd-v1-4" selected>sd-v1-4</option> -->
</select>
</li>
<li id="samplerSelection" class="pl-5"><label for="sampler">Sampler:</label>
<select id="sampler" name="sampler">
<option value="plms" selected>plms</option>
@ -150,6 +155,12 @@
<li class="pl-5"><label for="num_inference_steps">Number of inference steps:</label> <input id="num_inference_steps" name="num_inference_steps" size="4" value="50"></li>
<li class="pl-5"><label for="guidance_scale_slider">Guidance Scale:</label> <input id="guidance_scale_slider" name="guidance_scale_slider" class="editor-slider" value="75" type="range" min="10" max="500"> <input id="guidance_scale" name="guidance_scale" size="4"></li>
<li class="pl-5"><span id="prompt_strength_container"><label for="prompt_strength_slider">Prompt Strength:</label> <input id="prompt_strength_slider" name="prompt_strength_slider" class="editor-slider" value="80" type="range" min="0" max="99"> <input id="prompt_strength" name="prompt_strength" size="4"><br/></span></li>
<li class="pl-5"><label for="output_format">Output format:</label>
<select id="output_format" name="output_format">
<option value="jpeg" selected>jpeg</option>
<option value="png">png</option>
</select>
</li>
<br/>
@ -213,12 +224,13 @@
</div>
</body>
<script src="media/main.js?v=15"></script>
<script src="media/main.js?v=31"></script>
<script>
async function init() {
await loadModifiers()
await getDiskPath()
await getAppConfig()
await getModels()
setInterval(healthCheck, HEALTH_PING_INTERVAL * 1000)
healthCheck()

View File

@ -70,38 +70,36 @@ label {
font-size: 8pt;
}
.imgSeedLabel {
position: absolute;
transform: translateX(-100%);
margin-top: 5pt;
margin-left: -5pt;
font-size: 10pt;
background-color: #333;
opacity: 0.8;
color: #ddd;
border-radius: 3pt;
padding: 1pt 3pt;
}
.imgUseBtn {
position: absolute;
transform: translateX(-100%);
margin-top: 30pt;
margin-left: -5pt;
}
.imgSaveBtn {
position: absolute;
transform: translateX(-100%);
margin-top: 55pt;
margin-left: -5pt;
font-size: 0.8em;
background-color: rgb(44, 45, 48);
border-radius: 3px;
padding: 5px;
}
.imgItem {
display: inline;
padding-right: 10px;
display: inline-block;
margin-top: 1em;
margin-right: 1em;
}
.imgContainer {
display: flex;
justify-content: flex-end;
}
.imgItemInfo {
opacity: 0.5;
padding-bottom: 0.5em;
display: flex;
align-items: flex-end;
flex-direction: column;
position: absolute;
padding: 5px;
opacity: 0;
transition: 0.1s all;
}
.imgContainer:hover > .imgItemInfo {
opacity: 1;
}
.imgItemInfo * {
margin-bottom: 7px;
}
#container {
width: 90%;
margin-left: auto;
@ -409,4 +407,7 @@ img {
font-size: 10pt;
color: #aaa;
margin-bottom: 5pt;
}
.img-batch {
display: inline;
}

View File

@ -46,6 +46,8 @@ let samplerSelectionContainer = document.querySelector("#samplerSelection")
let useFaceCorrectionField = document.querySelector("#use_face_correction")
let useUpscalingField = document.querySelector("#use_upscale")
let upscaleModelField = document.querySelector("#upscale_model")
let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
let outputFormatField = document.querySelector('#output_format')
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
let updateBranchLabel = document.querySelector("#updateBranchLabel")
let streamImageProgressField = document.querySelector("#stream_image_progress")
@ -73,10 +75,10 @@ let editorTagsContainer = document.querySelector('#editor-inputs-tags-container'
let imagePreview = document.querySelector("#preview")
let previewImageField = document.querySelector('#preview-image')
previewImageField.onchange = () => changePreviewImages(previewImageField.value);
previewImageField.onchange = () => changePreviewImages(previewImageField.value)
let modifierCardSizeSlider = document.querySelector('#modifier-card-size-slider')
modifierCardSizeSlider.onchange = () => resizeModifierCards(modifierCardSizeSlider.value);
modifierCardSizeSlider.onchange = () => resizeModifierCards(modifierCardSizeSlider.value)
// let previewPrompt = document.querySelector('#preview-prompt')
@ -119,8 +121,8 @@ let bellPending = false
let taskQueue = []
let currentTask = null
const modifierThumbnailPath = 'media/modifier-thumbnails';
const activeCardClass = 'modifier-card-active';
const modifierThumbnailPath = 'media/modifier-thumbnails'
const activeCardClass = 'modifier-card-active'
function getLocalStorageItem(key, fallback) {
let item = localStorage.getItem(key)
@ -202,7 +204,7 @@ function isStreamImageProgressEnabled() {
function setStatus(statusType, msg, msgType) {
if (statusType !== 'server') {
return;
return
}
if (msgType == 'error') {
@ -259,18 +261,100 @@ async function healthCheck() {
}
}
function makeImageElement(width, height, outputContainer) {
let imgItem = document.createElement('div')
imgItem.className = 'imgItem'
function showImages(req, res, outputContainer, livePreview) {
let imageItemElements = outputContainer.querySelectorAll('.imgItem')
let img = document.createElement('img')
img.width = parseInt(width)
img.height = parseInt(height)
res.output.reverse()
imgItem.appendChild(img)
outputContainer.insertBefore(imgItem, outputContainer.firstChild)
res.output.forEach((result, index) => {
if(typeof res != 'object') return
return imgItem
const imageData = result?.data || result?.path + '?t=' + new Date().getTime(),
imageSeed = result?.seed,
imageWidth = req.width,
imageHeight = req.height;
if (!imageData.includes('/')) {
// res contained no data for the image, stop execution
setStatus('request', 'invalid image', 'error')
return
}
let imageItemElem = (index < imageItemElements.length ? imageItemElements[index] : null)
if(!imageItemElem) {
imageItemElem = document.createElement('div')
imageItemElem.className = 'imgItem'
imageItemElem.innerHTML = `
<div class="imgContainer">
<img/>
<div class="imgItemInfo">
<span class="imgSeedLabel"></span>
<button class="imgUseBtn">Use as Input</button>
<button class="imgSaveBtn">Download</button>
</div>
</div>
`
const useAsInputBtn = imageItemElem.querySelector('.imgUseBtn'),
saveImageBtn = imageItemElem.querySelector('.imgSaveBtn');
useAsInputBtn.addEventListener('click', getUseAsInputHandler(imageItemElem))
saveImageBtn.addEventListener('click', getSaveImageHandler(imageItemElem, req['output_format']))
outputContainer.appendChild(imageItemElem)
}
const imageElem = imageItemElem.querySelector('img'),
imageSeedLabel = imageItemElem.querySelector('.imgSeedLabel');
imageElem.src = imageData
imageElem.width = parseInt(imageWidth)
imageElem.height = parseInt(imageHeight)
imageElem.setAttribute('data-seed', imageSeed)
const imageInfo = imageItemElem.querySelector('.imgItemInfo')
imageInfo.style.visibility = (livePreview ? 'hidden' : 'visible')
imageSeedLabel.innerText = 'Seed: ' + imageSeed
})
}
function getUseAsInputHandler(imageItemElem) {
return function() {
const imageElem = imageItemElem.querySelector('img')
const imgData = imageElem.src
const imageSeed = imageElem.getAttribute('data-seed')
initImageSelector.value = null
initImagePreview.src = imgData
initImagePreviewContainer.style.display = 'block'
inpaintingEditorContainer.style.display = 'none'
promptStrengthContainer.style.display = 'block'
maskSetting.checked = false
samplerSelectionContainer.style.display = 'none'
// maskSetting.style.display = 'block'
randomSeedField.checked = false
seedField.value = imageSeed
seedField.disabled = false
}
}
function getSaveImageHandler(imageItemElem, outputFormat) {
return function() {
const imageElem = imageItemElem.querySelector('img')
const imgData = imageElem.src
const imageSeed = imageElem.getAttribute('data-seed')
const imgDownload = document.createElement('a')
imgDownload.download = createFileName(imageSeed, outputFormat)
imgDownload.href = imgData
imgDownload.click()
}
}
// makes a single image. don't call this directly, use makeImage() instead
@ -281,7 +365,10 @@ async function doMakeImage(task) {
const reqBody = task.reqBody
const batchCount = task.batchCount
const outputContainer = task.outputContainer
const outputContainer = document.createElement('div')
outputContainer.className = 'img-batch'
task.outputContainer.insertBefore(outputContainer, task.outputContainer.firstChild)
const outputMsg = task['outputMsg']
const previewPrompt = task['previewPrompt']
@ -291,14 +378,6 @@ async function doMakeImage(task) {
let seed = reqBody['seed']
let numOutputs = parseInt(reqBody['num_outputs'])
let images = []
function makeImageContainers(numImages) {
for (let i = images.length; i < numImages; i++) {
images.push(makeImageElement(reqBody.width, reqBody.height, outputContainer))
}
}
try {
res = await fetch('/image', {
method: 'POST',
@ -351,14 +430,7 @@ async function doMakeImage(task) {
outputMsg.style.display = 'block'
if (stepUpdate.output !== undefined) {
makeImageContainers(numOutputs)
for (idx in stepUpdate.output) {
let imgItem = images[idx]
let img = imgItem.firstChild
let tmpImageData = stepUpdate.output[idx]
img.src = tmpImageData['path'] + '?t=' + new Date().getTime()
}
showImages(reqBody, stepUpdate, outputContainer, true)
}
}
} catch (e) {
@ -426,85 +498,11 @@ async function doMakeImage(task) {
res = undefined
}
if (!res) {
return false
}
if (!res) return false
lastPromptUsed = reqBody['prompt']
makeImageContainers(res.output.length)
for (let idx in res.output) {
let imgBody = ''
let seed = 0
try {
let imgData = res.output[idx]
imgBody = imgData.data
seed = imgData.seed
} catch (e) {
console.log(imgBody)
setStatus('request', 'invalid image', 'error')
continue
}
let imgItem = images[idx]
let img = imgItem.firstChild
img.src = imgBody
let imgItemInfo = document.createElement('span')
imgItemInfo.className = 'imgItemInfo'
imgItemInfo.style.opacity = 0
let imgSeedLabel = document.createElement('span')
imgSeedLabel.className = 'imgSeedLabel'
imgSeedLabel.innerText = 'Seed: ' + seed
let imgUseBtn = document.createElement('button')
imgUseBtn.className = 'imgUseBtn'
imgUseBtn.innerText = 'Use as Input'
let imgSaveBtn = document.createElement('button')
imgSaveBtn.className = 'imgSaveBtn'
imgSaveBtn.innerText = 'Download'
imgItem.appendChild(imgItemInfo)
imgItemInfo.appendChild(imgSeedLabel)
imgItemInfo.appendChild(imgUseBtn)
imgItemInfo.appendChild(imgSaveBtn)
imgUseBtn.addEventListener('click', function() {
initImageSelector.value = null
initImagePreview.src = imgBody
initImagePreviewContainer.style.display = 'block'
inpaintingEditorContainer.style.display = 'none'
promptStrengthContainer.style.display = 'block'
maskSetting.checked = false
// maskSetting.style.display = 'block'
randomSeedField.checked = false
seedField.value = seed
seedField.disabled = false
})
imgSaveBtn.addEventListener('click', function() {
let imgDownload = document.createElement('a')
imgDownload.download = createFileName();
imgDownload.href = imgBody
imgDownload.click()
})
imgItem.addEventListener('mouseenter', function() {
imgItemInfo.style.opacity = 1
})
imgItem.addEventListener('mouseleave', function() {
imgItemInfo.style.opacity = 0
})
}
showImages(reqBody, res, outputContainer, false)
return true
}
@ -547,7 +545,6 @@ async function checkTasks() {
task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop'
task['taskStatusLabel'].innerText = "Processing"
task['taskStatusLabel'].className += " activeTaskLabel"
console.log(task['taskStatusLabel'].className)
for (let i = 0; i < task.batchCount; i++) {
task.reqBody['seed'] = task.seed + (i * task.reqBody['num_outputs'])
@ -612,8 +609,8 @@ async function makeImage() {
let prompt = promptField.value
if (activeTags.length > 0) {
let promptTags = activeTags.map(x => x.name).join(", ");
prompt += ", " + promptTags;
let promptTags = activeTags.map(x => x.name).join(", ")
prompt += ", " + promptTags
}
let reqBody = {
@ -629,9 +626,11 @@ async function makeImage() {
turbo: turboField.checked,
use_cpu: useCPUField.checked,
use_full_precision: useFullPrecisionField.checked,
use_stable_diffusion_model: stableDiffusionModelField.value,
stream_progress_updates: true,
stream_image_progress: streamImageProgress,
show_only_filtered_image: showOnlyFilteredImageField.checked
show_only_filtered_image: showOnlyFilteredImageField.checked,
output_format: outputFormatField.value
}
if (IMAGE_REGEX.test(initImagePreview.src)) {
@ -662,7 +661,7 @@ async function makeImage() {
reqBody['use_upscale'] = upscaleModelField.value
}
let taskConfig = `Seed: ${seed}, Sampler: ${reqBody['sampler']}, Inference Steps: ${numInferenceStepsField.value}, Guidance Scale: ${guidanceScaleField.value}`
let taskConfig = `Seed: ${seed}, Sampler: ${reqBody['sampler']}, Inference Steps: ${numInferenceStepsField.value}, Guidance Scale: ${guidanceScaleField.value}, Model: ${stableDiffusionModelField.value}`
if (negativePromptField.value.trim() !== '') {
taskConfig += `, Negative Prompt: ${negativePromptField.value.trim()}`
@ -736,12 +735,11 @@ async function makeImage() {
// create a file name with embedded prompt and metadata
// for easier cateloging and comparison
function createFileName() {
function createFileName(seed, outputFormat) {
// Most important information is the prompt
let underscoreName = lastPromptUsed.replace(/[^a-zA-Z0-9]/g, '_')
underscoreName = underscoreName.substring(0, 100)
const seed = seedField.value
const steps = numInferenceStepsField.value
const guidance = guidanceScaleField.value
@ -749,23 +747,23 @@ function createFileName() {
let fileName = `${underscoreName}_Seed-${seed}_Steps-${steps}_Guidance-${guidance}`
// add the tags
// let tags = [];
// let tagString = '';
// let tags = []
// let tagString = ''
// document.querySelectorAll(modifyTagsSelector).forEach(function(tag) {
// tags.push(tag.innerHTML);
// tags.push(tag.innerHTML)
// })
// join the tags with a pipe
// if (activeTags.length > 0) {
// tagString = '_Tags-';
// tagString += tags.join('|');
// tagString = '_Tags-'
// tagString += tags.join('|')
// }
// // append empty or populated tags
// fileName += `${tagString}`;
// fileName += `${tagString}`
// add the file extension
fileName += `.png`
fileName += '.' + (outputFormat === 'png' ? 'png' : 'jpeg')
return fileName
}
@ -939,6 +937,33 @@ async function getAppConfig() {
}
}
async function getModels() {
try {
let res = await fetch('/models')
models = await res.json()
let activeModel = models['active']
let modelOptions = models['options']
let stableDiffusionOptions = modelOptions['stable-diffusion']
stableDiffusionOptions.forEach(modelName => {
let modelOption = document.createElement('option')
modelOption.value = modelName
modelOption.innerText = modelName
if (modelName === activeModel['stable-diffusion']) {
modelOption.selected = true
}
stableDiffusionModelField.appendChild(modelOption)
})
console.log('get models response', config)
} catch (e) {
console.log('get models error', e)
}
}
function checkRandomSeed() {
if (randomSeedField.checked) {
seedField.disabled = true
@ -1037,25 +1062,25 @@ maskSetting.addEventListener('click', function() {
// https://stackoverflow.com/a/8212878
function millisecondsToStr(milliseconds) {
function numberEnding (number) {
return (number > 1) ? 's' : '';
return (number > 1) ? 's' : ''
}
var temp = Math.floor(milliseconds / 1000);
var hours = Math.floor((temp %= 86400) / 3600);
var temp = Math.floor(milliseconds / 1000)
var hours = Math.floor((temp %= 86400) / 3600)
var s = ''
if (hours) {
s += hours + ' hour' + numberEnding(hours) + ' ';
s += hours + ' hour' + numberEnding(hours) + ' '
}
var minutes = Math.floor((temp %= 3600) / 60);
var minutes = Math.floor((temp %= 3600) / 60)
if (minutes) {
s += minutes + ' minute' + numberEnding(minutes) + ' ';
s += minutes + ' minute' + numberEnding(minutes) + ' '
}
var seconds = temp % 60;
var seconds = temp % 60
if (!hours && minutes < 4 && seconds) {
s += seconds + ' second' + numberEnding(seconds);
s += seconds + ' second' + numberEnding(seconds)
}
return s;
return s
}
// https://gomakethings.com/finding-the-next-and-previous-sibling-elements-that-match-a-selector-with-vanilla-js/
@ -1115,33 +1140,33 @@ function createCollapsibles(node) {
createCollapsibles()
function refreshTagsList() {
editorModifierTagsList.innerHTML = '';
editorModifierTagsList.innerHTML = ''
if (activeTags.length == 0) {
editorTagsContainer.style.display = 'none';
return;
editorTagsContainer.style.display = 'none'
return
} else {
editorTagsContainer.style.display = 'block';
editorTagsContainer.style.display = 'block'
}
activeTags.forEach((tag, index) => {
tag.element.querySelector('.modifier-card-image-overlay').innerText = '-';
tag.element.classList.add('modifier-card-tiny');
tag.element.querySelector('.modifier-card-image-overlay').innerText = '-'
tag.element.classList.add('modifier-card-tiny')
editorModifierTagsList.appendChild(tag.element);
editorModifierTagsList.appendChild(tag.element)
tag.element.addEventListener('click', () => {
let idx = activeTags.indexOf(tag);
let idx = activeTags.indexOf(tag)
if (idx !== -1) {
activeTags[idx].originElement.classList.remove(activeCardClass);
activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+';
activeTags[idx].originElement.classList.remove(activeCardClass)
activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+'
activeTags.splice(idx, 1);
refreshTagsList();
activeTags.splice(idx, 1)
refreshTagsList()
}
});
});
})
})
let brk = document.createElement('br')
brk.style.clear = 'both'
@ -1170,8 +1195,8 @@ async function getDiskPath() {
}
function createModifierCard(name, previews) {
const modifierCard = document.createElement('div');
modifierCard.className = 'modifier-card';
const modifierCard = document.createElement('div')
modifierCard.className = 'modifier-card'
modifierCard.innerHTML = `
<div class="modifier-card-overlay"></div>
<div class="modifier-card-image-container">
@ -1181,96 +1206,96 @@ function createModifierCard(name, previews) {
</div>
<div class="modifier-card-container">
<div class="modifier-card-label"><p></p></div>
</div>`;
</div>`
const image = modifierCard.querySelector('.modifier-card-image');
const errorText = modifierCard.querySelector('.modifier-card-error-label');
const label = modifierCard.querySelector('.modifier-card-label');
const image = modifierCard.querySelector('.modifier-card-image')
const errorText = modifierCard.querySelector('.modifier-card-error-label')
const label = modifierCard.querySelector('.modifier-card-label')
errorText.innerText = 'No Image';
errorText.innerText = 'No Image'
if (typeof previews == 'object') {
image.src = previews[0]; // portrait
image.setAttribute('preview-type', 'portrait');
image.setAttribute('preview-type', 'portrait')
} else {
image.remove();
image.remove()
}
const maxLabelLength = 30;
const nameWithoutBy = name.replace('by ', '');
const maxLabelLength = 30
const nameWithoutBy = name.replace('by ', '')
if(nameWithoutBy.length <= maxLabelLength) {
label.querySelector('p').innerText = nameWithoutBy;
label.querySelector('p').innerText = nameWithoutBy
} else {
const tooltipText = document.createElement('span');
tooltipText.className = 'tooltip-text';
tooltipText.innerText = name;
const tooltipText = document.createElement('span')
tooltipText.className = 'tooltip-text'
tooltipText.innerText = name
label.classList.add('tooltip');
label.appendChild(tooltipText);
label.classList.add('tooltip')
label.appendChild(tooltipText)
label.querySelector('p').innerText = nameWithoutBy.substring(0, maxLabelLength) + '...';
label.querySelector('p').innerText = nameWithoutBy.substring(0, maxLabelLength) + '...'
}
return modifierCard;
return modifierCard
}
function changePreviewImages(val) {
const previewImages = document.querySelectorAll('.modifier-card-image-container img');
const previewImages = document.querySelectorAll('.modifier-card-image-container img')
let previewArr = [];
let previewArr = []
modifiers.map(x => x.modifiers).forEach(x => previewArr.push(...x.map(m => m.previews)));
modifiers.map(x => x.modifiers).forEach(x => previewArr.push(...x.map(m => m.previews)))
previewArr = previewArr.map(x => {
let obj = {};
let obj = {}
x.forEach(preview => {
obj[preview.name] = preview.path;
});
obj[preview.name] = preview.path
})
return obj;
});
return obj
})
previewImages.forEach(previewImage => {
const currentPreviewType = previewImage.getAttribute('preview-type');
const relativePreviewPath = previewImage.src.split(modifierThumbnailPath + '/').pop();
const currentPreviewType = previewImage.getAttribute('preview-type')
const relativePreviewPath = previewImage.src.split(modifierThumbnailPath + '/').pop()
const previews = previewArr.find(preview => relativePreviewPath == preview[currentPreviewType]);
const previews = previewArr.find(preview => relativePreviewPath == preview[currentPreviewType])
if(typeof previews == 'object') {
let preview = null;
let preview = null
if (val == 'portrait') {
preview = previews.portrait;
preview = previews.portrait
}
else if (val == 'landscape') {
preview = previews.landscape;
preview = previews.landscape
}
if(preview != null) {
previewImage.src = `${modifierThumbnailPath}/${preview}`;
previewImage.setAttribute('preview-type', val);
previewImage.src = `${modifierThumbnailPath}/${preview}`
previewImage.setAttribute('preview-type', val)
}
}
});
})
}
function resizeModifierCards(val) {
const cardSizePrefix = 'modifier-card-size_';
const modifierCardClass = 'modifier-card';
const cardSizePrefix = 'modifier-card-size_'
const modifierCardClass = 'modifier-card'
const modifierCards = document.querySelectorAll(`.${modifierCardClass}`);
const cardSize = n => `${cardSizePrefix}${n}`;
const modifierCards = document.querySelectorAll(`.${modifierCardClass}`)
const cardSize = n => `${cardSizePrefix}${n}`
modifierCards.forEach(card => {
// remove existing size classes
const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix));
card.className = classes.join(' ').trim();
const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix))
card.className = classes.join(' ').trim()
if(val != 0)
card.classList.add(cardSize(val));
});
card.classList.add(cardSize(val))
})
}
async function loadModifiers() {
@ -1282,15 +1307,15 @@ async function loadModifiers() {
modifiers = res; // update global variable
res.forEach((modifierGroup, idx) => {
const title = modifierGroup.category;
const modifiers = modifierGroup.modifiers;
const title = modifierGroup.category
const modifiers = modifierGroup.modifiers
const titleEl = document.createElement('h5');
titleEl.className = 'collapsible';
titleEl.innerText = title;
const titleEl = document.createElement('h5')
titleEl.className = 'collapsible'
titleEl.innerText = title
const modifiersEl = document.createElement('div');
modifiersEl.classList.add('collapsible-content', 'editor-modifiers-leaf');
const modifiersEl = document.createElement('div')
modifiersEl.classList.add('collapsible-content', 'editor-modifiers-leaf')
if (idx == 0) {
titleEl.className += ' active'
@ -1298,21 +1323,21 @@ async function loadModifiers() {
}
modifiers.forEach(modObj => {
const modifierName = modObj.modifier;
const modifierPreviews = modObj?.previews?.map(preview => `${modifierThumbnailPath}/${preview.path}`);
const modifierName = modObj.modifier
const modifierPreviews = modObj?.previews?.map(preview => `${modifierThumbnailPath}/${preview.path}`)
const modifierCard = createModifierCard(modifierName, modifierPreviews);
const modifierCard = createModifierCard(modifierName, modifierPreviews)
if(typeof modifierCard == 'object') {
modifiersEl.appendChild(modifierCard);
modifiersEl.appendChild(modifierCard)
modifierCard.addEventListener('click', () => {
if (activeTags.map(x => x.name).includes(modifierName)) {
// remove modifier from active array
activeTags = activeTags.filter(x => x.name != modifierName);
modifierCard.classList.remove(activeCardClass);
activeTags = activeTags.filter(x => x.name != modifierName)
modifierCard.classList.remove(activeCardClass)
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+';
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+'
} else {
// add modifier to active array
activeTags.push({
@ -1320,17 +1345,17 @@ async function loadModifiers() {
'element': modifierCard.cloneNode(true),
'originElement': modifierCard,
'previews': modifierPreviews
});
})
modifierCard.classList.add(activeCardClass);
modifierCard.classList.add(activeCardClass)
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-';
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-'
}
refreshTagsList();
});
refreshTagsList()
})
}
});
})
let brk = document.createElement('br')
brk.style.clear = 'both'
@ -1348,4 +1373,4 @@ async function loadModifiers() {
} catch (e) {
console.log('error fetching modifiers', e)
}
}
}

View File

@ -22,7 +22,9 @@ class Request:
use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
stream_progress_updates: bool = False
stream_image_progress: bool = False
@ -42,6 +44,8 @@ class Request:
"sampler": self.sampler,
"use_face_correction": self.use_face_correction,
"use_upscale": self.use_upscale,
"use_stable_diffusion_model": self.use_stable_diffusion_model,
"output_format": self.output_format,
}
def to_string(self):
@ -62,7 +66,9 @@ class Request:
use_full_precision: {self.use_full_precision}
use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale}
use_stable_diffusion_model: {self.use_stable_diffusion_model}
show_only_filtered_image: {self.show_only_filtered_image}
output_format: {self.output_format}
stream_progress_updates: {self.stream_progress_updates}
stream_image_progress: {self.stream_image_progress}'''

View File

@ -79,7 +79,7 @@ except:
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
pass
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast', half_model_fs=False):
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast'):
global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
ckpt_file = ckpt_to_use
@ -130,14 +130,11 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
if device != "cpu" and precision == "autocast":
model.half()
modelCS.half()
model_is_half = True
else:
model_is_half = False
if half_model_fs:
modelFS.half()
model_is_half = True
model_fs_is_half = True
else:
model_is_half = False
model_fs_is_half = False
print('loaded ', ckpt_file, 'to', device, 'precision', precision)
@ -208,6 +205,7 @@ def mk_img(req: Request):
})
def do_mk_img(req: Request):
global ckpt_file
global model, modelCS, modelFS, device
global model_gfpgan, model_real_esrgan
global stop_processing
@ -220,6 +218,15 @@ def do_mk_img(req: Request):
temp_images.clear()
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
needs_model_reload = False
if ckpt_file != req.use_stable_diffusion_model:
ckpt_file = req.use_stable_diffusion_model
needs_model_reload = True
model.turbo = req.turbo
if req.use_cpu:
if device != 'cpu':
@ -228,6 +235,7 @@ def do_mk_img(req: Request):
if model_is_half:
del model, modelCS, modelFS
load_model_ckpt(ckpt_file, device)
needs_model_reload = False
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
@ -237,17 +245,19 @@ def do_mk_img(req: Request):
device = 'cuda'
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
(precision == 'full' and not req.use_full_precision and not force_full_precision) or \
(req.init_image is None and model_fs_is_half) or \
(req.init_image is not None and not model_fs_is_half and not force_full_precision):
(precision == 'full' and not req.use_full_precision and not force_full_precision):
del model, modelCS, modelFS
load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision))
load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'))
needs_model_reload = False
if prev_device != device:
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
if needs_model_reload:
load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, precision)
if req.use_face_correction != gfpgan_file:
load_model_gfpgan(req.use_face_correction)
@ -274,7 +284,7 @@ def do_mk_img(req: Request):
opt_use_face_correction = req.use_face_correction
opt_use_upscale = req.use_upscale
opt_show_only_filtered = req.show_only_filtered_image
opt_format = 'png'
opt_format = req.output_format
opt_sampler_name = req.sampler
print(req.to_string(), '\n device', device)
@ -444,10 +454,10 @@ def do_mk_img(req: Request):
if return_orig_img:
save_image(img, img_out_path)
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt)
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt, ckpt_file)
if return_orig_img:
img_data = img_to_base64_str(img)
img_data = img_to_base64_str(img, opt_format)
res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
res.images.append(res_image_orig)
@ -474,7 +484,7 @@ def do_mk_img(req: Request):
filtered_image = Image.fromarray(x_sample)
filtered_img_data = img_to_base64_str(filtered_image)
filtered_img_data = img_to_base64_str(filtered_image, opt_format)
res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed)
res.images.append(res_image_filtered)
@ -505,8 +515,8 @@ def save_image(img, img_out_path):
except:
print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt):
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}"
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt, ckpt_file):
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}\nStable Diffusion Model: {ckpt_file + '.ckpt'}"
try:
with open(meta_out_path, 'w') as f:
@ -642,9 +652,9 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False):
return image
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img):
def img_to_base64_str(img, output_format="PNG"):
buffered = BytesIO()
img.save(buffered, format="PNG")
img.save(buffered, format=output_format)
buffered.seek(0)
img_byte = buffered.getvalue()
img_str = "data:image/png;base64," + base64.b64encode(img_byte).decode()

View File

@ -4,13 +4,14 @@ import traceback
import sys
import os
SCRIPT_DIR = os.getcwd()
print('started in ', SCRIPT_DIR)
SD_DIR = os.getcwd()
print('started in ', SD_DIR)
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
sys.path.append(os.path.dirname(SD_UI_DIR))
CONFIG_DIR = os.path.join(SD_UI_DIR, '..', 'scripts')
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
@ -57,7 +58,9 @@ class ImageRequest(BaseModel):
use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
stream_progress_updates: bool = False
stream_image_progress: bool = False
@ -85,9 +88,7 @@ async def ping():
from sd_internal import runtime
custom_weight_path = os.path.join(SCRIPT_DIR, 'custom-model.ckpt')
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
runtime.load_model_ckpt(ckpt_to_use=ckpt_to_use)
runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load())
model_loaded = True
model_is_loading = False
@ -97,6 +98,46 @@ async def ping():
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
# needs to support the legacy installations
def get_initial_model_to_load():
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use)
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion']
model_path = resolve_model_to_use(model_name)
if os.path.exists(model_path + '.ckpt'):
ckpt_to_use = model_path
else:
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
return ckpt_to_use
def resolve_model_to_use(model_name):
if model_name in ('sd-v1-4', 'custom-model'):
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
legacy_model_path = os.path.join(SD_DIR, model_name)
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
model_path = legacy_model_path
else:
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
return model_path
def save_model_to_config(model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = model_name
setConfig(config)
@app.post('/image')
def image(req : ImageRequest):
from sd_internal import runtime
@ -123,10 +164,15 @@ def image(req : ImageRequest):
r.use_upscale: str = req.use_upscale
r.use_face_correction = req.use_face_correction
r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format
r.stream_progress_updates = True # the underlying implementation only supports streaming
r.stream_image_progress = req.stream_image_progress
r.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
save_model_to_config(req.use_stable_diffusion_model)
try:
if not req.stream_progress_updates:
r.stream_image_progress = False
@ -205,13 +251,62 @@ def getAppConfig():
return HTTPException(status_code=500, detail="No config file")
with open(config_json_path, 'r') as f:
config_json_str = f.read()
config = json.loads(config_json_str)
return config
return json.load(f)
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
def getConfig():
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return {}
with open(config_json_path, 'r') as f:
return json.load(f)
except Exception as e:
return {}
def setConfig(config):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w') as f:
return json.dump(config, f)
except:
print(traceback.format_exc())
@app.get('/models')
def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
},
}
# custom models
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
for file in os.listdir(sd_models_dir):
if file.endswith('.ckpt'):
model_name = os.path.splitext(file)[0]
models['options']['stable-diffusion'].append(model_name)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['active']['stable-diffusion'] = 'custom-model'
models['options']['stable-diffusion'].append('custom-model')
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
return models
@app.get('/modifiers.json')
def read_modifiers():
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}