Merge pull request #292 from cmdr2/beta

Custom Models; Reduce RAM usage; JPG/PNG option
This commit is contained in:
cmdr2 2022-10-07 19:45:17 +05:30 committed by GitHub
commit 23d20c918f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 421 additions and 266 deletions

View File

@ -165,6 +165,9 @@ call WHERE uvicorn > .tmp
if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
@if exist "sd-v1-4.ckpt" ( @if exist "sd-v1-4.ckpt" (
for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" ( for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (
echo "Data files (weights) necessary for Stable Diffusion were already downloaded. Using the HuggingFace 4 GB Model." echo "Data files (weights) necessary for Stable Diffusion were already downloaded. Using the HuggingFace 4 GB Model."

View File

@ -161,6 +161,9 @@ fi
mkdir -p "../models/stable-diffusion"
echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
if [ -f "sd-v1-4.ckpt" ]; then if [ -f "sd-v1-4.ckpt" ]; then
model_size=`find "sd-v1-4.ckpt" -printf "%s"` model_size=`find "sd-v1-4.ckpt" -printf "%s"`

View File

@ -4,7 +4,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="icon" type="image/png" href="/media/favicon-16x16.png" sizes="16x16"> <link rel="icon" type="image/png" href="/media/favicon-16x16.png" sizes="16x16">
<link rel="icon" type="image/png" href="/media/favicon-32x32.png" sizes="32x32"> <link rel="icon" type="image/png" href="/media/favicon-32x32.png" sizes="32x32">
<link rel="stylesheet" href="/media/main.css?v=10"> <link rel="stylesheet" href="/media/main.css?v=21">
<link rel="stylesheet" href="/media/modifier-thumbnails.css?v=1"> <link rel="stylesheet" href="/media/modifier-thumbnails.css?v=1">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css"> <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css">
<link rel="stylesheet" href="/media/drawingboard.min.css"> <link rel="stylesheet" href="/media/drawingboard.min.css">
@ -15,7 +15,7 @@
<div id="container"> <div id="container">
<div id="top-nav"> <div id="top-nav">
<div id="logo"> <div id="logo">
<h1>Stable Diffusion UI <small>v2.195 <span id="updateBranchLabel"></span></small></h1> <h1>Stable Diffusion UI <small>v2.2 <span id="updateBranchLabel"></span></small></h1>
</div> </div>
<ul id="top-nav-items"> <ul id="top-nav-items">
<li class="dropdown"> <li class="dropdown">
@ -89,6 +89,11 @@
<li><b class="settings-subheader">Image Settings</b></li> <li><b class="settings-subheader">Image Settings</b></li>
<li class="pl-5"><label for="seed">Seed:</label> <input id="seed" name="seed" size="10" value="30000"> <input id="random_seed" name="random_seed" type="checkbox" checked> <label for="random_seed">Random Image</label></li> <li class="pl-5"><label for="seed">Seed:</label> <input id="seed" name="seed" size="10" value="30000"> <input id="random_seed" name="random_seed" type="checkbox" checked> <label for="random_seed">Random Image</label></li>
<li class="pl-5"><label for="num_outputs_total">Number of images to make:</label> <input id="num_outputs_total" name="num_outputs_total" value="1" size="1"> <label for="num_outputs_parallel">Generate in parallel:</label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="1"> (images at once)</li> <li class="pl-5"><label for="num_outputs_total">Number of images to make:</label> <input id="num_outputs_total" name="num_outputs_total" value="1" size="1"> <label for="num_outputs_parallel">Generate in parallel:</label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="1"> (images at once)</li>
<li class="pl-5"><label for="stable_diffusion_model">Model:</label>
<select id="stable_diffusion_model" name="stable_diffusion_model">
<!-- <option value="sd-v1-4" selected>sd-v1-4</option> -->
</select>
</li>
<li id="samplerSelection" class="pl-5"><label for="sampler">Sampler:</label> <li id="samplerSelection" class="pl-5"><label for="sampler">Sampler:</label>
<select id="sampler" name="sampler"> <select id="sampler" name="sampler">
<option value="plms" selected>plms</option> <option value="plms" selected>plms</option>
@ -150,6 +155,12 @@
<li class="pl-5"><label for="num_inference_steps">Number of inference steps:</label> <input id="num_inference_steps" name="num_inference_steps" size="4" value="50"></li> <li class="pl-5"><label for="num_inference_steps">Number of inference steps:</label> <input id="num_inference_steps" name="num_inference_steps" size="4" value="50"></li>
<li class="pl-5"><label for="guidance_scale_slider">Guidance Scale:</label> <input id="guidance_scale_slider" name="guidance_scale_slider" class="editor-slider" value="75" type="range" min="10" max="500"> <input id="guidance_scale" name="guidance_scale" size="4"></li> <li class="pl-5"><label for="guidance_scale_slider">Guidance Scale:</label> <input id="guidance_scale_slider" name="guidance_scale_slider" class="editor-slider" value="75" type="range" min="10" max="500"> <input id="guidance_scale" name="guidance_scale" size="4"></li>
<li class="pl-5"><span id="prompt_strength_container"><label for="prompt_strength_slider">Prompt Strength:</label> <input id="prompt_strength_slider" name="prompt_strength_slider" class="editor-slider" value="80" type="range" min="0" max="99"> <input id="prompt_strength" name="prompt_strength" size="4"><br/></span></li> <li class="pl-5"><span id="prompt_strength_container"><label for="prompt_strength_slider">Prompt Strength:</label> <input id="prompt_strength_slider" name="prompt_strength_slider" class="editor-slider" value="80" type="range" min="0" max="99"> <input id="prompt_strength" name="prompt_strength" size="4"><br/></span></li>
<li class="pl-5"><label for="output_format">Output format:</label>
<select id="output_format" name="output_format">
<option value="jpeg" selected>jpeg</option>
<option value="png">png</option>
</select>
</li>
<br/> <br/>
@ -213,12 +224,13 @@
</div> </div>
</body> </body>
<script src="media/main.js?v=15"></script> <script src="media/main.js?v=31"></script>
<script> <script>
async function init() { async function init() {
await loadModifiers() await loadModifiers()
await getDiskPath() await getDiskPath()
await getAppConfig() await getAppConfig()
await getModels()
setInterval(healthCheck, HEALTH_PING_INTERVAL * 1000) setInterval(healthCheck, HEALTH_PING_INTERVAL * 1000)
healthCheck() healthCheck()

View File

@ -70,38 +70,36 @@ label {
font-size: 8pt; font-size: 8pt;
} }
.imgSeedLabel { .imgSeedLabel {
position: absolute; font-size: 0.8em;
transform: translateX(-100%); background-color: rgb(44, 45, 48);
margin-top: 5pt; border-radius: 3px;
margin-left: -5pt; padding: 5px;
font-size: 10pt;
background-color: #333;
opacity: 0.8;
color: #ddd;
border-radius: 3pt;
padding: 1pt 3pt;
}
.imgUseBtn {
position: absolute;
transform: translateX(-100%);
margin-top: 30pt;
margin-left: -5pt;
}
.imgSaveBtn {
position: absolute;
transform: translateX(-100%);
margin-top: 55pt;
margin-left: -5pt;
} }
.imgItem { .imgItem {
display: inline; display: inline-block;
padding-right: 10px; margin-top: 1em;
margin-right: 1em;
}
.imgContainer {
display: flex;
justify-content: flex-end;
} }
.imgItemInfo { .imgItemInfo {
opacity: 0.5; padding-bottom: 0.5em;
display: flex;
align-items: flex-end;
flex-direction: column;
position: absolute;
padding: 5px;
opacity: 0;
transition: 0.1s all;
}
.imgContainer:hover > .imgItemInfo {
opacity: 1;
}
.imgItemInfo * {
margin-bottom: 7px;
} }
#container { #container {
width: 90%; width: 90%;
margin-left: auto; margin-left: auto;
@ -409,4 +407,7 @@ img {
font-size: 10pt; font-size: 10pt;
color: #aaa; color: #aaa;
margin-bottom: 5pt; margin-bottom: 5pt;
}
.img-batch {
display: inline;
} }

View File

@ -46,6 +46,8 @@ let samplerSelectionContainer = document.querySelector("#samplerSelection")
let useFaceCorrectionField = document.querySelector("#use_face_correction") let useFaceCorrectionField = document.querySelector("#use_face_correction")
let useUpscalingField = document.querySelector("#use_upscale") let useUpscalingField = document.querySelector("#use_upscale")
let upscaleModelField = document.querySelector("#upscale_model") let upscaleModelField = document.querySelector("#upscale_model")
let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
let outputFormatField = document.querySelector('#output_format')
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image") let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
let updateBranchLabel = document.querySelector("#updateBranchLabel") let updateBranchLabel = document.querySelector("#updateBranchLabel")
let streamImageProgressField = document.querySelector("#stream_image_progress") let streamImageProgressField = document.querySelector("#stream_image_progress")
@ -73,10 +75,10 @@ let editorTagsContainer = document.querySelector('#editor-inputs-tags-container'
let imagePreview = document.querySelector("#preview") let imagePreview = document.querySelector("#preview")
let previewImageField = document.querySelector('#preview-image') let previewImageField = document.querySelector('#preview-image')
previewImageField.onchange = () => changePreviewImages(previewImageField.value); previewImageField.onchange = () => changePreviewImages(previewImageField.value)
let modifierCardSizeSlider = document.querySelector('#modifier-card-size-slider') let modifierCardSizeSlider = document.querySelector('#modifier-card-size-slider')
modifierCardSizeSlider.onchange = () => resizeModifierCards(modifierCardSizeSlider.value); modifierCardSizeSlider.onchange = () => resizeModifierCards(modifierCardSizeSlider.value)
// let previewPrompt = document.querySelector('#preview-prompt') // let previewPrompt = document.querySelector('#preview-prompt')
@ -119,8 +121,8 @@ let bellPending = false
let taskQueue = [] let taskQueue = []
let currentTask = null let currentTask = null
const modifierThumbnailPath = 'media/modifier-thumbnails'; const modifierThumbnailPath = 'media/modifier-thumbnails'
const activeCardClass = 'modifier-card-active'; const activeCardClass = 'modifier-card-active'
function getLocalStorageItem(key, fallback) { function getLocalStorageItem(key, fallback) {
let item = localStorage.getItem(key) let item = localStorage.getItem(key)
@ -202,7 +204,7 @@ function isStreamImageProgressEnabled() {
function setStatus(statusType, msg, msgType) { function setStatus(statusType, msg, msgType) {
if (statusType !== 'server') { if (statusType !== 'server') {
return; return
} }
if (msgType == 'error') { if (msgType == 'error') {
@ -259,18 +261,100 @@ async function healthCheck() {
} }
} }
function makeImageElement(width, height, outputContainer) { function showImages(req, res, outputContainer, livePreview) {
let imgItem = document.createElement('div') let imageItemElements = outputContainer.querySelectorAll('.imgItem')
imgItem.className = 'imgItem'
let img = document.createElement('img') res.output.reverse()
img.width = parseInt(width)
img.height = parseInt(height)
imgItem.appendChild(img) res.output.forEach((result, index) => {
outputContainer.insertBefore(imgItem, outputContainer.firstChild) if(typeof res != 'object') return
return imgItem const imageData = result?.data || result?.path + '?t=' + new Date().getTime(),
imageSeed = result?.seed,
imageWidth = req.width,
imageHeight = req.height;
if (!imageData.includes('/')) {
// res contained no data for the image, stop execution
setStatus('request', 'invalid image', 'error')
return
}
let imageItemElem = (index < imageItemElements.length ? imageItemElements[index] : null)
if(!imageItemElem) {
imageItemElem = document.createElement('div')
imageItemElem.className = 'imgItem'
imageItemElem.innerHTML = `
<div class="imgContainer">
<img/>
<div class="imgItemInfo">
<span class="imgSeedLabel"></span>
<button class="imgUseBtn">Use as Input</button>
<button class="imgSaveBtn">Download</button>
</div>
</div>
`
const useAsInputBtn = imageItemElem.querySelector('.imgUseBtn'),
saveImageBtn = imageItemElem.querySelector('.imgSaveBtn');
useAsInputBtn.addEventListener('click', getUseAsInputHandler(imageItemElem))
saveImageBtn.addEventListener('click', getSaveImageHandler(imageItemElem, req['output_format']))
outputContainer.appendChild(imageItemElem)
}
const imageElem = imageItemElem.querySelector('img'),
imageSeedLabel = imageItemElem.querySelector('.imgSeedLabel');
imageElem.src = imageData
imageElem.width = parseInt(imageWidth)
imageElem.height = parseInt(imageHeight)
imageElem.setAttribute('data-seed', imageSeed)
const imageInfo = imageItemElem.querySelector('.imgItemInfo')
imageInfo.style.visibility = (livePreview ? 'hidden' : 'visible')
imageSeedLabel.innerText = 'Seed: ' + imageSeed
})
}
function getUseAsInputHandler(imageItemElem) {
return function() {
const imageElem = imageItemElem.querySelector('img')
const imgData = imageElem.src
const imageSeed = imageElem.getAttribute('data-seed')
initImageSelector.value = null
initImagePreview.src = imgData
initImagePreviewContainer.style.display = 'block'
inpaintingEditorContainer.style.display = 'none'
promptStrengthContainer.style.display = 'block'
maskSetting.checked = false
samplerSelectionContainer.style.display = 'none'
// maskSetting.style.display = 'block'
randomSeedField.checked = false
seedField.value = imageSeed
seedField.disabled = false
}
}
function getSaveImageHandler(imageItemElem, outputFormat) {
return function() {
const imageElem = imageItemElem.querySelector('img')
const imgData = imageElem.src
const imageSeed = imageElem.getAttribute('data-seed')
const imgDownload = document.createElement('a')
imgDownload.download = createFileName(imageSeed, outputFormat)
imgDownload.href = imgData
imgDownload.click()
}
} }
// makes a single image. don't call this directly, use makeImage() instead // makes a single image. don't call this directly, use makeImage() instead
@ -281,7 +365,10 @@ async function doMakeImage(task) {
const reqBody = task.reqBody const reqBody = task.reqBody
const batchCount = task.batchCount const batchCount = task.batchCount
const outputContainer = task.outputContainer const outputContainer = document.createElement('div')
outputContainer.className = 'img-batch'
task.outputContainer.insertBefore(outputContainer, task.outputContainer.firstChild)
const outputMsg = task['outputMsg'] const outputMsg = task['outputMsg']
const previewPrompt = task['previewPrompt'] const previewPrompt = task['previewPrompt']
@ -291,14 +378,6 @@ async function doMakeImage(task) {
let seed = reqBody['seed'] let seed = reqBody['seed']
let numOutputs = parseInt(reqBody['num_outputs']) let numOutputs = parseInt(reqBody['num_outputs'])
let images = []
function makeImageContainers(numImages) {
for (let i = images.length; i < numImages; i++) {
images.push(makeImageElement(reqBody.width, reqBody.height, outputContainer))
}
}
try { try {
res = await fetch('/image', { res = await fetch('/image', {
method: 'POST', method: 'POST',
@ -351,14 +430,7 @@ async function doMakeImage(task) {
outputMsg.style.display = 'block' outputMsg.style.display = 'block'
if (stepUpdate.output !== undefined) { if (stepUpdate.output !== undefined) {
makeImageContainers(numOutputs) showImages(reqBody, stepUpdate, outputContainer, true)
for (idx in stepUpdate.output) {
let imgItem = images[idx]
let img = imgItem.firstChild
let tmpImageData = stepUpdate.output[idx]
img.src = tmpImageData['path'] + '?t=' + new Date().getTime()
}
} }
} }
} catch (e) { } catch (e) {
@ -426,85 +498,11 @@ async function doMakeImage(task) {
res = undefined res = undefined
} }
if (!res) { if (!res) return false
return false
}
lastPromptUsed = reqBody['prompt'] lastPromptUsed = reqBody['prompt']
makeImageContainers(res.output.length) showImages(reqBody, res, outputContainer, false)
for (let idx in res.output) {
let imgBody = ''
let seed = 0
try {
let imgData = res.output[idx]
imgBody = imgData.data
seed = imgData.seed
} catch (e) {
console.log(imgBody)
setStatus('request', 'invalid image', 'error')
continue
}
let imgItem = images[idx]
let img = imgItem.firstChild
img.src = imgBody
let imgItemInfo = document.createElement('span')
imgItemInfo.className = 'imgItemInfo'
imgItemInfo.style.opacity = 0
let imgSeedLabel = document.createElement('span')
imgSeedLabel.className = 'imgSeedLabel'
imgSeedLabel.innerText = 'Seed: ' + seed
let imgUseBtn = document.createElement('button')
imgUseBtn.className = 'imgUseBtn'
imgUseBtn.innerText = 'Use as Input'
let imgSaveBtn = document.createElement('button')
imgSaveBtn.className = 'imgSaveBtn'
imgSaveBtn.innerText = 'Download'
imgItem.appendChild(imgItemInfo)
imgItemInfo.appendChild(imgSeedLabel)
imgItemInfo.appendChild(imgUseBtn)
imgItemInfo.appendChild(imgSaveBtn)
imgUseBtn.addEventListener('click', function() {
initImageSelector.value = null
initImagePreview.src = imgBody
initImagePreviewContainer.style.display = 'block'
inpaintingEditorContainer.style.display = 'none'
promptStrengthContainer.style.display = 'block'
maskSetting.checked = false
// maskSetting.style.display = 'block'
randomSeedField.checked = false
seedField.value = seed
seedField.disabled = false
})
imgSaveBtn.addEventListener('click', function() {
let imgDownload = document.createElement('a')
imgDownload.download = createFileName();
imgDownload.href = imgBody
imgDownload.click()
})
imgItem.addEventListener('mouseenter', function() {
imgItemInfo.style.opacity = 1
})
imgItem.addEventListener('mouseleave', function() {
imgItemInfo.style.opacity = 0
})
}
return true return true
} }
@ -547,7 +545,6 @@ async function checkTasks() {
task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop' task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop'
task['taskStatusLabel'].innerText = "Processing" task['taskStatusLabel'].innerText = "Processing"
task['taskStatusLabel'].className += " activeTaskLabel" task['taskStatusLabel'].className += " activeTaskLabel"
console.log(task['taskStatusLabel'].className)
for (let i = 0; i < task.batchCount; i++) { for (let i = 0; i < task.batchCount; i++) {
task.reqBody['seed'] = task.seed + (i * task.reqBody['num_outputs']) task.reqBody['seed'] = task.seed + (i * task.reqBody['num_outputs'])
@ -612,8 +609,8 @@ async function makeImage() {
let prompt = promptField.value let prompt = promptField.value
if (activeTags.length > 0) { if (activeTags.length > 0) {
let promptTags = activeTags.map(x => x.name).join(", "); let promptTags = activeTags.map(x => x.name).join(", ")
prompt += ", " + promptTags; prompt += ", " + promptTags
} }
let reqBody = { let reqBody = {
@ -629,9 +626,11 @@ async function makeImage() {
turbo: turboField.checked, turbo: turboField.checked,
use_cpu: useCPUField.checked, use_cpu: useCPUField.checked,
use_full_precision: useFullPrecisionField.checked, use_full_precision: useFullPrecisionField.checked,
use_stable_diffusion_model: stableDiffusionModelField.value,
stream_progress_updates: true, stream_progress_updates: true,
stream_image_progress: streamImageProgress, stream_image_progress: streamImageProgress,
show_only_filtered_image: showOnlyFilteredImageField.checked show_only_filtered_image: showOnlyFilteredImageField.checked,
output_format: outputFormatField.value
} }
if (IMAGE_REGEX.test(initImagePreview.src)) { if (IMAGE_REGEX.test(initImagePreview.src)) {
@ -662,7 +661,7 @@ async function makeImage() {
reqBody['use_upscale'] = upscaleModelField.value reqBody['use_upscale'] = upscaleModelField.value
} }
let taskConfig = `Seed: ${seed}, Sampler: ${reqBody['sampler']}, Inference Steps: ${numInferenceStepsField.value}, Guidance Scale: ${guidanceScaleField.value}` let taskConfig = `Seed: ${seed}, Sampler: ${reqBody['sampler']}, Inference Steps: ${numInferenceStepsField.value}, Guidance Scale: ${guidanceScaleField.value}, Model: ${stableDiffusionModelField.value}`
if (negativePromptField.value.trim() !== '') { if (negativePromptField.value.trim() !== '') {
taskConfig += `, Negative Prompt: ${negativePromptField.value.trim()}` taskConfig += `, Negative Prompt: ${negativePromptField.value.trim()}`
@ -736,12 +735,11 @@ async function makeImage() {
// create a file name with embedded prompt and metadata // create a file name with embedded prompt and metadata
// for easier cateloging and comparison // for easier cateloging and comparison
function createFileName() { function createFileName(seed, outputFormat) {
// Most important information is the prompt // Most important information is the prompt
let underscoreName = lastPromptUsed.replace(/[^a-zA-Z0-9]/g, '_') let underscoreName = lastPromptUsed.replace(/[^a-zA-Z0-9]/g, '_')
underscoreName = underscoreName.substring(0, 100) underscoreName = underscoreName.substring(0, 100)
const seed = seedField.value
const steps = numInferenceStepsField.value const steps = numInferenceStepsField.value
const guidance = guidanceScaleField.value const guidance = guidanceScaleField.value
@ -749,23 +747,23 @@ function createFileName() {
let fileName = `${underscoreName}_Seed-${seed}_Steps-${steps}_Guidance-${guidance}` let fileName = `${underscoreName}_Seed-${seed}_Steps-${steps}_Guidance-${guidance}`
// add the tags // add the tags
// let tags = []; // let tags = []
// let tagString = ''; // let tagString = ''
// document.querySelectorAll(modifyTagsSelector).forEach(function(tag) { // document.querySelectorAll(modifyTagsSelector).forEach(function(tag) {
// tags.push(tag.innerHTML); // tags.push(tag.innerHTML)
// }) // })
// join the tags with a pipe // join the tags with a pipe
// if (activeTags.length > 0) { // if (activeTags.length > 0) {
// tagString = '_Tags-'; // tagString = '_Tags-'
// tagString += tags.join('|'); // tagString += tags.join('|')
// } // }
// // append empty or populated tags // // append empty or populated tags
// fileName += `${tagString}`; // fileName += `${tagString}`
// add the file extension // add the file extension
fileName += `.png` fileName += '.' + (outputFormat === 'png' ? 'png' : 'jpeg')
return fileName return fileName
} }
@ -939,6 +937,33 @@ async function getAppConfig() {
} }
} }
async function getModels() {
try {
let res = await fetch('/models')
models = await res.json()
let activeModel = models['active']
let modelOptions = models['options']
let stableDiffusionOptions = modelOptions['stable-diffusion']
stableDiffusionOptions.forEach(modelName => {
let modelOption = document.createElement('option')
modelOption.value = modelName
modelOption.innerText = modelName
if (modelName === activeModel['stable-diffusion']) {
modelOption.selected = true
}
stableDiffusionModelField.appendChild(modelOption)
})
console.log('get models response', config)
} catch (e) {
console.log('get models error', e)
}
}
function checkRandomSeed() { function checkRandomSeed() {
if (randomSeedField.checked) { if (randomSeedField.checked) {
seedField.disabled = true seedField.disabled = true
@ -1037,25 +1062,25 @@ maskSetting.addEventListener('click', function() {
// https://stackoverflow.com/a/8212878 // https://stackoverflow.com/a/8212878
function millisecondsToStr(milliseconds) { function millisecondsToStr(milliseconds) {
function numberEnding (number) { function numberEnding (number) {
return (number > 1) ? 's' : ''; return (number > 1) ? 's' : ''
} }
var temp = Math.floor(milliseconds / 1000); var temp = Math.floor(milliseconds / 1000)
var hours = Math.floor((temp %= 86400) / 3600); var hours = Math.floor((temp %= 86400) / 3600)
var s = '' var s = ''
if (hours) { if (hours) {
s += hours + ' hour' + numberEnding(hours) + ' '; s += hours + ' hour' + numberEnding(hours) + ' '
} }
var minutes = Math.floor((temp %= 3600) / 60); var minutes = Math.floor((temp %= 3600) / 60)
if (minutes) { if (minutes) {
s += minutes + ' minute' + numberEnding(minutes) + ' '; s += minutes + ' minute' + numberEnding(minutes) + ' '
} }
var seconds = temp % 60; var seconds = temp % 60
if (!hours && minutes < 4 && seconds) { if (!hours && minutes < 4 && seconds) {
s += seconds + ' second' + numberEnding(seconds); s += seconds + ' second' + numberEnding(seconds)
} }
return s; return s
} }
// https://gomakethings.com/finding-the-next-and-previous-sibling-elements-that-match-a-selector-with-vanilla-js/ // https://gomakethings.com/finding-the-next-and-previous-sibling-elements-that-match-a-selector-with-vanilla-js/
@ -1115,33 +1140,33 @@ function createCollapsibles(node) {
createCollapsibles() createCollapsibles()
function refreshTagsList() { function refreshTagsList() {
editorModifierTagsList.innerHTML = ''; editorModifierTagsList.innerHTML = ''
if (activeTags.length == 0) { if (activeTags.length == 0) {
editorTagsContainer.style.display = 'none'; editorTagsContainer.style.display = 'none'
return; return
} else { } else {
editorTagsContainer.style.display = 'block'; editorTagsContainer.style.display = 'block'
} }
activeTags.forEach((tag, index) => { activeTags.forEach((tag, index) => {
tag.element.querySelector('.modifier-card-image-overlay').innerText = '-'; tag.element.querySelector('.modifier-card-image-overlay').innerText = '-'
tag.element.classList.add('modifier-card-tiny'); tag.element.classList.add('modifier-card-tiny')
editorModifierTagsList.appendChild(tag.element); editorModifierTagsList.appendChild(tag.element)
tag.element.addEventListener('click', () => { tag.element.addEventListener('click', () => {
let idx = activeTags.indexOf(tag); let idx = activeTags.indexOf(tag)
if (idx !== -1) { if (idx !== -1) {
activeTags[idx].originElement.classList.remove(activeCardClass); activeTags[idx].originElement.classList.remove(activeCardClass)
activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+'; activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+'
activeTags.splice(idx, 1); activeTags.splice(idx, 1)
refreshTagsList(); refreshTagsList()
} }
}); })
}); })
let brk = document.createElement('br') let brk = document.createElement('br')
brk.style.clear = 'both' brk.style.clear = 'both'
@ -1170,8 +1195,8 @@ async function getDiskPath() {
} }
function createModifierCard(name, previews) { function createModifierCard(name, previews) {
const modifierCard = document.createElement('div'); const modifierCard = document.createElement('div')
modifierCard.className = 'modifier-card'; modifierCard.className = 'modifier-card'
modifierCard.innerHTML = ` modifierCard.innerHTML = `
<div class="modifier-card-overlay"></div> <div class="modifier-card-overlay"></div>
<div class="modifier-card-image-container"> <div class="modifier-card-image-container">
@ -1181,96 +1206,96 @@ function createModifierCard(name, previews) {
</div> </div>
<div class="modifier-card-container"> <div class="modifier-card-container">
<div class="modifier-card-label"><p></p></div> <div class="modifier-card-label"><p></p></div>
</div>`; </div>`
const image = modifierCard.querySelector('.modifier-card-image'); const image = modifierCard.querySelector('.modifier-card-image')
const errorText = modifierCard.querySelector('.modifier-card-error-label'); const errorText = modifierCard.querySelector('.modifier-card-error-label')
const label = modifierCard.querySelector('.modifier-card-label'); const label = modifierCard.querySelector('.modifier-card-label')
errorText.innerText = 'No Image'; errorText.innerText = 'No Image'
if (typeof previews == 'object') { if (typeof previews == 'object') {
image.src = previews[0]; // portrait image.src = previews[0]; // portrait
image.setAttribute('preview-type', 'portrait'); image.setAttribute('preview-type', 'portrait')
} else { } else {
image.remove(); image.remove()
} }
const maxLabelLength = 30; const maxLabelLength = 30
const nameWithoutBy = name.replace('by ', ''); const nameWithoutBy = name.replace('by ', '')
if(nameWithoutBy.length <= maxLabelLength) { if(nameWithoutBy.length <= maxLabelLength) {
label.querySelector('p').innerText = nameWithoutBy; label.querySelector('p').innerText = nameWithoutBy
} else { } else {
const tooltipText = document.createElement('span'); const tooltipText = document.createElement('span')
tooltipText.className = 'tooltip-text'; tooltipText.className = 'tooltip-text'
tooltipText.innerText = name; tooltipText.innerText = name
label.classList.add('tooltip'); label.classList.add('tooltip')
label.appendChild(tooltipText); label.appendChild(tooltipText)
label.querySelector('p').innerText = nameWithoutBy.substring(0, maxLabelLength) + '...'; label.querySelector('p').innerText = nameWithoutBy.substring(0, maxLabelLength) + '...'
} }
return modifierCard; return modifierCard
} }
function changePreviewImages(val) { function changePreviewImages(val) {
const previewImages = document.querySelectorAll('.modifier-card-image-container img'); const previewImages = document.querySelectorAll('.modifier-card-image-container img')
let previewArr = []; let previewArr = []
modifiers.map(x => x.modifiers).forEach(x => previewArr.push(...x.map(m => m.previews))); modifiers.map(x => x.modifiers).forEach(x => previewArr.push(...x.map(m => m.previews)))
previewArr = previewArr.map(x => { previewArr = previewArr.map(x => {
let obj = {}; let obj = {}
x.forEach(preview => { x.forEach(preview => {
obj[preview.name] = preview.path; obj[preview.name] = preview.path
}); })
return obj; return obj
}); })
previewImages.forEach(previewImage => { previewImages.forEach(previewImage => {
const currentPreviewType = previewImage.getAttribute('preview-type'); const currentPreviewType = previewImage.getAttribute('preview-type')
const relativePreviewPath = previewImage.src.split(modifierThumbnailPath + '/').pop(); const relativePreviewPath = previewImage.src.split(modifierThumbnailPath + '/').pop()
const previews = previewArr.find(preview => relativePreviewPath == preview[currentPreviewType]); const previews = previewArr.find(preview => relativePreviewPath == preview[currentPreviewType])
if(typeof previews == 'object') { if(typeof previews == 'object') {
let preview = null; let preview = null
if (val == 'portrait') { if (val == 'portrait') {
preview = previews.portrait; preview = previews.portrait
} }
else if (val == 'landscape') { else if (val == 'landscape') {
preview = previews.landscape; preview = previews.landscape
} }
if(preview != null) { if(preview != null) {
previewImage.src = `${modifierThumbnailPath}/${preview}`; previewImage.src = `${modifierThumbnailPath}/${preview}`
previewImage.setAttribute('preview-type', val); previewImage.setAttribute('preview-type', val)
} }
} }
}); })
} }
function resizeModifierCards(val) { function resizeModifierCards(val) {
const cardSizePrefix = 'modifier-card-size_'; const cardSizePrefix = 'modifier-card-size_'
const modifierCardClass = 'modifier-card'; const modifierCardClass = 'modifier-card'
const modifierCards = document.querySelectorAll(`.${modifierCardClass}`); const modifierCards = document.querySelectorAll(`.${modifierCardClass}`)
const cardSize = n => `${cardSizePrefix}${n}`; const cardSize = n => `${cardSizePrefix}${n}`
modifierCards.forEach(card => { modifierCards.forEach(card => {
// remove existing size classes // remove existing size classes
const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix)); const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix))
card.className = classes.join(' ').trim(); card.className = classes.join(' ').trim()
if(val != 0) if(val != 0)
card.classList.add(cardSize(val)); card.classList.add(cardSize(val))
}); })
} }
async function loadModifiers() { async function loadModifiers() {
@ -1282,15 +1307,15 @@ async function loadModifiers() {
modifiers = res; // update global variable modifiers = res; // update global variable
res.forEach((modifierGroup, idx) => { res.forEach((modifierGroup, idx) => {
const title = modifierGroup.category; const title = modifierGroup.category
const modifiers = modifierGroup.modifiers; const modifiers = modifierGroup.modifiers
const titleEl = document.createElement('h5'); const titleEl = document.createElement('h5')
titleEl.className = 'collapsible'; titleEl.className = 'collapsible'
titleEl.innerText = title; titleEl.innerText = title
const modifiersEl = document.createElement('div'); const modifiersEl = document.createElement('div')
modifiersEl.classList.add('collapsible-content', 'editor-modifiers-leaf'); modifiersEl.classList.add('collapsible-content', 'editor-modifiers-leaf')
if (idx == 0) { if (idx == 0) {
titleEl.className += ' active' titleEl.className += ' active'
@ -1298,21 +1323,21 @@ async function loadModifiers() {
} }
modifiers.forEach(modObj => { modifiers.forEach(modObj => {
const modifierName = modObj.modifier; const modifierName = modObj.modifier
const modifierPreviews = modObj?.previews?.map(preview => `${modifierThumbnailPath}/${preview.path}`); const modifierPreviews = modObj?.previews?.map(preview => `${modifierThumbnailPath}/${preview.path}`)
const modifierCard = createModifierCard(modifierName, modifierPreviews); const modifierCard = createModifierCard(modifierName, modifierPreviews)
if(typeof modifierCard == 'object') { if(typeof modifierCard == 'object') {
modifiersEl.appendChild(modifierCard); modifiersEl.appendChild(modifierCard)
modifierCard.addEventListener('click', () => { modifierCard.addEventListener('click', () => {
if (activeTags.map(x => x.name).includes(modifierName)) { if (activeTags.map(x => x.name).includes(modifierName)) {
// remove modifier from active array // remove modifier from active array
activeTags = activeTags.filter(x => x.name != modifierName); activeTags = activeTags.filter(x => x.name != modifierName)
modifierCard.classList.remove(activeCardClass); modifierCard.classList.remove(activeCardClass)
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+'; modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+'
} else { } else {
// add modifier to active array // add modifier to active array
activeTags.push({ activeTags.push({
@ -1320,17 +1345,17 @@ async function loadModifiers() {
'element': modifierCard.cloneNode(true), 'element': modifierCard.cloneNode(true),
'originElement': modifierCard, 'originElement': modifierCard,
'previews': modifierPreviews 'previews': modifierPreviews
}); })
modifierCard.classList.add(activeCardClass); modifierCard.classList.add(activeCardClass)
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-'; modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-'
} }
refreshTagsList(); refreshTagsList()
}); })
} }
}); })
let brk = document.createElement('br') let brk = document.createElement('br')
brk.style.clear = 'both' brk.style.clear = 'both'
@ -1348,4 +1373,4 @@ async function loadModifiers() {
} catch (e) { } catch (e) {
console.log('error fetching modifiers', e) console.log('error fetching modifiers', e)
} }
} }

View File

@ -22,7 +22,9 @@ class Request:
use_full_precision: bool = False use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3" use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
show_only_filtered_image: bool = False show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
stream_progress_updates: bool = False stream_progress_updates: bool = False
stream_image_progress: bool = False stream_image_progress: bool = False
@ -42,6 +44,8 @@ class Request:
"sampler": self.sampler, "sampler": self.sampler,
"use_face_correction": self.use_face_correction, "use_face_correction": self.use_face_correction,
"use_upscale": self.use_upscale, "use_upscale": self.use_upscale,
"use_stable_diffusion_model": self.use_stable_diffusion_model,
"output_format": self.output_format,
} }
def to_string(self): def to_string(self):
@ -62,7 +66,9 @@ class Request:
use_full_precision: {self.use_full_precision} use_full_precision: {self.use_full_precision}
use_face_correction: {self.use_face_correction} use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale} use_upscale: {self.use_upscale}
use_stable_diffusion_model: {self.use_stable_diffusion_model}
show_only_filtered_image: {self.show_only_filtered_image} show_only_filtered_image: {self.show_only_filtered_image}
output_format: {self.output_format}
stream_progress_updates: {self.stream_progress_updates} stream_progress_updates: {self.stream_progress_updates}
stream_image_progress: {self.stream_image_progress}''' stream_image_progress: {self.stream_image_progress}'''

View File

@ -79,7 +79,7 @@ except:
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!') print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
pass pass
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast', half_model_fs=False): def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast'):
global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
ckpt_file = ckpt_to_use ckpt_file = ckpt_to_use
@ -130,14 +130,11 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
if device != "cpu" and precision == "autocast": if device != "cpu" and precision == "autocast":
model.half() model.half()
modelCS.half() modelCS.half()
model_is_half = True
else:
model_is_half = False
if half_model_fs:
modelFS.half() modelFS.half()
model_is_half = True
model_fs_is_half = True model_fs_is_half = True
else: else:
model_is_half = False
model_fs_is_half = False model_fs_is_half = False
print('loaded ', ckpt_file, 'to', device, 'precision', precision) print('loaded ', ckpt_file, 'to', device, 'precision', precision)
@ -208,6 +205,7 @@ def mk_img(req: Request):
}) })
def do_mk_img(req: Request): def do_mk_img(req: Request):
global ckpt_file
global model, modelCS, modelFS, device global model, modelCS, modelFS, device
global model_gfpgan, model_real_esrgan global model_gfpgan, model_real_esrgan
global stop_processing global stop_processing
@ -220,6 +218,15 @@ def do_mk_img(req: Request):
temp_images.clear() temp_images.clear()
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
needs_model_reload = False
if ckpt_file != req.use_stable_diffusion_model:
ckpt_file = req.use_stable_diffusion_model
needs_model_reload = True
model.turbo = req.turbo model.turbo = req.turbo
if req.use_cpu: if req.use_cpu:
if device != 'cpu': if device != 'cpu':
@ -228,6 +235,7 @@ def do_mk_img(req: Request):
if model_is_half: if model_is_half:
del model, modelCS, modelFS del model, modelCS, modelFS
load_model_ckpt(ckpt_file, device) load_model_ckpt(ckpt_file, device)
needs_model_reload = False
load_model_gfpgan(gfpgan_file) load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file) load_model_real_esrgan(real_esrgan_file)
@ -237,17 +245,19 @@ def do_mk_img(req: Request):
device = 'cuda' device = 'cuda'
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \ if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
(precision == 'full' and not req.use_full_precision and not force_full_precision) or \ (precision == 'full' and not req.use_full_precision and not force_full_precision):
(req.init_image is None and model_fs_is_half) or \
(req.init_image is not None and not model_fs_is_half and not force_full_precision):
del model, modelCS, modelFS del model, modelCS, modelFS
load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision)) load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'))
needs_model_reload = False
if prev_device != device: if prev_device != device:
load_model_gfpgan(gfpgan_file) load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file) load_model_real_esrgan(real_esrgan_file)
if needs_model_reload:
load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, precision)
if req.use_face_correction != gfpgan_file: if req.use_face_correction != gfpgan_file:
load_model_gfpgan(req.use_face_correction) load_model_gfpgan(req.use_face_correction)
@ -274,7 +284,7 @@ def do_mk_img(req: Request):
opt_use_face_correction = req.use_face_correction opt_use_face_correction = req.use_face_correction
opt_use_upscale = req.use_upscale opt_use_upscale = req.use_upscale
opt_show_only_filtered = req.show_only_filtered_image opt_show_only_filtered = req.show_only_filtered_image
opt_format = 'png' opt_format = req.output_format
opt_sampler_name = req.sampler opt_sampler_name = req.sampler
print(req.to_string(), '\n device', device) print(req.to_string(), '\n device', device)
@ -444,10 +454,10 @@ def do_mk_img(req: Request):
if return_orig_img: if return_orig_img:
save_image(img, img_out_path) save_image(img, img_out_path)
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt) save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt, ckpt_file)
if return_orig_img: if return_orig_img:
img_data = img_to_base64_str(img) img_data = img_to_base64_str(img, opt_format)
res_image_orig = ResponseImage(data=img_data, seed=opt_seed) res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
res.images.append(res_image_orig) res.images.append(res_image_orig)
@ -474,7 +484,7 @@ def do_mk_img(req: Request):
filtered_image = Image.fromarray(x_sample) filtered_image = Image.fromarray(x_sample)
filtered_img_data = img_to_base64_str(filtered_image) filtered_img_data = img_to_base64_str(filtered_image, opt_format)
res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed) res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed)
res.images.append(res_image_filtered) res.images.append(res_image_filtered)
@ -505,8 +515,8 @@ def save_image(img, img_out_path):
except: except:
print('could not save the file', traceback.format_exc()) print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt): def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt, ckpt_file):
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}" metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}\nStable Diffusion Model: {ckpt_file + '.ckpt'}"
try: try:
with open(meta_out_path, 'w') as f: with open(meta_out_path, 'w') as f:
@ -642,9 +652,9 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False):
return image return image
# https://stackoverflow.com/a/61114178 # https://stackoverflow.com/a/61114178
def img_to_base64_str(img): def img_to_base64_str(img, output_format="PNG"):
buffered = BytesIO() buffered = BytesIO()
img.save(buffered, format="PNG") img.save(buffered, format=output_format)
buffered.seek(0) buffered.seek(0)
img_byte = buffered.getvalue() img_byte = buffered.getvalue()
img_str = "data:image/png;base64," + base64.b64encode(img_byte).decode() img_str = "data:image/png;base64," + base64.b64encode(img_byte).decode()

View File

@ -4,13 +4,14 @@ import traceback
import sys import sys
import os import os
SCRIPT_DIR = os.getcwd() SD_DIR = os.getcwd()
print('started in ', SCRIPT_DIR) print('started in ', SD_DIR)
SD_UI_DIR = os.getenv('SD_UI_PATH', None) SD_UI_DIR = os.getenv('SD_UI_PATH', None)
sys.path.append(os.path.dirname(SD_UI_DIR)) sys.path.append(os.path.dirname(SD_UI_DIR))
CONFIG_DIR = os.path.join(SD_UI_DIR, '..', 'scripts') CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
@ -57,7 +58,9 @@ class ImageRequest(BaseModel):
use_full_precision: bool = False use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3" use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
show_only_filtered_image: bool = False show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
stream_progress_updates: bool = False stream_progress_updates: bool = False
stream_image_progress: bool = False stream_image_progress: bool = False
@ -85,9 +88,7 @@ async def ping():
from sd_internal import runtime from sd_internal import runtime
custom_weight_path = os.path.join(SCRIPT_DIR, 'custom-model.ckpt') runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load())
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
runtime.load_model_ckpt(ckpt_to_use=ckpt_to_use)
model_loaded = True model_loaded = True
model_is_loading = False model_is_loading = False
@ -97,6 +98,46 @@ async def ping():
print(traceback.format_exc()) print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
# needs to support the legacy installations
def get_initial_model_to_load():
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use)
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion']
model_path = resolve_model_to_use(model_name)
if os.path.exists(model_path + '.ckpt'):
ckpt_to_use = model_path
else:
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
return ckpt_to_use
def resolve_model_to_use(model_name):
if model_name in ('sd-v1-4', 'custom-model'):
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
legacy_model_path = os.path.join(SD_DIR, model_name)
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
model_path = legacy_model_path
else:
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
return model_path
def save_model_to_config(model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = model_name
setConfig(config)
@app.post('/image') @app.post('/image')
def image(req : ImageRequest): def image(req : ImageRequest):
from sd_internal import runtime from sd_internal import runtime
@ -123,10 +164,15 @@ def image(req : ImageRequest):
r.use_upscale: str = req.use_upscale r.use_upscale: str = req.use_upscale
r.use_face_correction = req.use_face_correction r.use_face_correction = req.use_face_correction
r.show_only_filtered_image = req.show_only_filtered_image r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format
r.stream_progress_updates = True # the underlying implementation only supports streaming r.stream_progress_updates = True # the underlying implementation only supports streaming
r.stream_image_progress = req.stream_image_progress r.stream_image_progress = req.stream_image_progress
r.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
save_model_to_config(req.use_stable_diffusion_model)
try: try:
if not req.stream_progress_updates: if not req.stream_progress_updates:
r.stream_image_progress = False r.stream_image_progress = False
@ -205,13 +251,62 @@ def getAppConfig():
return HTTPException(status_code=500, detail="No config file") return HTTPException(status_code=500, detail="No config file")
with open(config_json_path, 'r') as f: with open(config_json_path, 'r') as f:
config_json_str = f.read() return json.load(f)
config = json.loads(config_json_str)
return config
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
def getConfig():
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return {}
with open(config_json_path, 'r') as f:
return json.load(f)
except Exception as e:
return {}
def setConfig(config):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w') as f:
return json.dump(config, f)
except:
print(traceback.format_exc())
@app.get('/models')
def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
},
}
# custom models
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
for file in os.listdir(sd_models_dir):
if file.endswith('.ckpt'):
model_name = os.path.splitext(file)[0]
models['options']['stable-diffusion'].append(model_name)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['active']['stable-diffusion'] = 'custom-model'
models['options']['stable-diffusion'].append('custom-model')
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
return models
@app.get('/modifiers.json') @app.get('/modifiers.json')
def read_modifiers(): def read_modifiers():
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}