forked from extern/easydiffusion
Merge pull request #346 from cmdr2/task_manager
Allow multiple tabs and computers to generate tasks without throwing errors (by @madrang)
This commit is contained in:
commit
1c171d0f12
@ -16,7 +16,7 @@
|
|||||||
<div id="container">
|
<div id="container">
|
||||||
<div id="top-nav">
|
<div id="top-nav">
|
||||||
<div id="logo">
|
<div id="logo">
|
||||||
<h1>Stable Diffusion UI <small>v2.26 <span id="updateBranchLabel"></span></small></h1>
|
<h1>Stable Diffusion UI <small>v2.27 <span id="updateBranchLabel"></span></small></h1>
|
||||||
</div>
|
</div>
|
||||||
<ul id="top-nav-items">
|
<ul id="top-nav-items">
|
||||||
<li class="dropdown">
|
<li class="dropdown">
|
||||||
|
@ -326,13 +326,13 @@ img {
|
|||||||
height: 8pt;
|
height: 8pt;
|
||||||
border-radius: 4pt; */
|
border-radius: 4pt; */
|
||||||
font-size: 14pt;
|
font-size: 14pt;
|
||||||
color: rgb(128, 87, 0);
|
color: rgb(200, 139, 0);
|
||||||
/* background-color: rgb(197, 1, 1); */
|
/* background-color: rgb(197, 1, 1); */
|
||||||
/* transform: translateY(15%); */
|
/* transform: translateY(15%); */
|
||||||
display: inline;
|
display: inline;
|
||||||
}
|
}
|
||||||
#server-status-msg {
|
#server-status-msg {
|
||||||
color: rgb(128, 87, 0);
|
color: rgb(200, 139, 0);
|
||||||
padding-left: 2pt;
|
padding-left: 2pt;
|
||||||
font-size: 10pt;
|
font-size: 10pt;
|
||||||
}
|
}
|
||||||
@ -479,7 +479,12 @@ img {
|
|||||||
.activeTaskLabel {
|
.activeTaskLabel {
|
||||||
background:rgb(0, 90, 30);
|
background:rgb(0, 90, 30);
|
||||||
border: 1px solid rgb(0, 75, 19);
|
border: 1px solid rgb(0, 75, 19);
|
||||||
color:rgb(204, 255, 217)
|
color:rgb(222, 253, 230)
|
||||||
|
}
|
||||||
|
.waitingTaskLabel {
|
||||||
|
background:rgb(128, 89, 0);
|
||||||
|
border: 1px solid rgb(0, 75, 19);
|
||||||
|
color:rgb(255, 242, 211)
|
||||||
}
|
}
|
||||||
.secondaryButton {
|
.secondaryButton {
|
||||||
background: rgb(132, 8, 0);
|
background: rgb(132, 8, 0);
|
||||||
|
205
ui/media/main.js
205
ui/media/main.js
@ -20,7 +20,7 @@ const INPAINTING_EDITOR_SIZE = 450
|
|||||||
|
|
||||||
const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64')
|
const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64')
|
||||||
|
|
||||||
let sessionId = new Date().getTime()
|
let sessionId = Date.now()
|
||||||
|
|
||||||
let promptField = document.querySelector('#prompt')
|
let promptField = document.querySelector('#prompt')
|
||||||
let promptsFromFileSelector = document.querySelector('#prompt_from_file')
|
let promptsFromFileSelector = document.querySelector('#prompt_from_file')
|
||||||
@ -122,7 +122,7 @@ maskResetButton.innerHTML = 'Clear'
|
|||||||
maskResetButton.style.fontWeight = 'normal'
|
maskResetButton.style.fontWeight = 'normal'
|
||||||
maskResetButton.style.fontSize = '10pt'
|
maskResetButton.style.fontSize = '10pt'
|
||||||
|
|
||||||
let serverStatus = 'offline'
|
let serverState = {'status': 'Offline', 'time': Date.now()}
|
||||||
let activeTags = []
|
let activeTags = []
|
||||||
let modifiers = []
|
let modifiers = []
|
||||||
let lastPromptUsed = ''
|
let lastPromptUsed = ''
|
||||||
@ -225,21 +225,38 @@ function getOutputFormat() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function setStatus(statusType, msg, msgType) {
|
function setStatus(statusType, msg, msgType) {
|
||||||
if (statusType !== 'server') {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (msgType == 'error') {
|
function setServerStatus(msgType, msg) {
|
||||||
// msg = '<span style="color: red">' + msg + '<span>'
|
switch(msgType) {
|
||||||
|
case 'online':
|
||||||
|
serverStatusColor.style.color = 'green'
|
||||||
|
serverStatusMsg.style.color = 'green'
|
||||||
|
serverStatusMsg.innerText = 'Stable Diffusion is ' + msg
|
||||||
|
break
|
||||||
|
case 'busy':
|
||||||
|
serverStatusColor.style.color = 'rgb(200, 139, 0)'
|
||||||
|
serverStatusMsg.style.color = 'rgb(200, 139, 0)'
|
||||||
|
serverStatusMsg.innerText = 'Stable Diffusion is ' + msg
|
||||||
|
break
|
||||||
|
case 'error':
|
||||||
serverStatusColor.style.color = 'red'
|
serverStatusColor.style.color = 'red'
|
||||||
serverStatusMsg.style.color = 'red'
|
serverStatusMsg.style.color = 'red'
|
||||||
serverStatusMsg.innerText = 'Stable Diffusion has stopped'
|
serverStatusMsg.innerText = 'Stable Diffusion has stopped'
|
||||||
} else if (msgType == 'success') {
|
break
|
||||||
// msg = '<span style="color: green">' + msg + '<span>'
|
}
|
||||||
serverStatusColor.style.color = 'green'
|
}
|
||||||
serverStatusMsg.style.color = 'green'
|
function isServerAvailable() {
|
||||||
serverStatusMsg.innerText = 'Stable Diffusion is ready'
|
if (typeof serverState !== 'object') {
|
||||||
serverStatus = 'online'
|
return false
|
||||||
|
}
|
||||||
|
switch (serverState.status) {
|
||||||
|
case 'LoadingModel':
|
||||||
|
case 'Rendering':
|
||||||
|
case 'Online':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,6 +280,11 @@ function logError(msg, res, outputMsg) {
|
|||||||
console.log('request error', res)
|
console.log('request error', res)
|
||||||
setStatus('request', 'error', 'error')
|
setStatus('request', 'error', 'error')
|
||||||
}
|
}
|
||||||
|
function asyncDelay(timeout) {
|
||||||
|
return new Promise(function(resolve, reject) {
|
||||||
|
setTimeout(resolve, timeout, true)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
function playSound() {
|
function playSound() {
|
||||||
const audio = new Audio('/media/ding.mp3')
|
const audio = new Audio('/media/ding.mp3')
|
||||||
@ -277,16 +299,40 @@ function playSound() {
|
|||||||
|
|
||||||
async function healthCheck() {
|
async function healthCheck() {
|
||||||
try {
|
try {
|
||||||
let res = await fetch('/ping')
|
let res = undefined
|
||||||
res = await res.json()
|
if (sessionId) {
|
||||||
|
res = await fetch('/ping?session_id=' + sessionId)
|
||||||
if (res[0] == 'OK') {
|
|
||||||
setStatus('server', 'online', 'success')
|
|
||||||
} else {
|
} else {
|
||||||
setStatus('server', 'offline', 'error')
|
res = await fetch('/ping')
|
||||||
}
|
}
|
||||||
|
serverState = await res.json()
|
||||||
|
if (typeof serverState !== 'object' || typeof serverState.status !== 'string') {
|
||||||
|
serverState = {'status': 'Offline', 'time': Date.now()}
|
||||||
|
setServerStatus('error', 'offline')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Set status
|
||||||
|
switch(serverState.status) {
|
||||||
|
case 'Init':
|
||||||
|
// Wait for init to complete before updating status.
|
||||||
|
break
|
||||||
|
case 'Online':
|
||||||
|
setServerStatus('online', 'ready')
|
||||||
|
break
|
||||||
|
case 'LoadingModel':
|
||||||
|
setServerStatus('busy', 'loading model')
|
||||||
|
break
|
||||||
|
case 'Rendering':
|
||||||
|
setServerStatus('busy', 'rendering')
|
||||||
|
break
|
||||||
|
default: // Unavailable
|
||||||
|
setServerStatus('error', serverState.status.toLowerCase())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
serverState.time = Date.now()
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
setStatus('server', 'offline', 'error')
|
serverState = {'status': 'Offline', 'time': Date.now()}
|
||||||
|
setServerStatus('error', 'offline')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
function resizeInpaintingEditor() {
|
function resizeInpaintingEditor() {
|
||||||
@ -329,7 +375,7 @@ function showImages(reqBody, res, outputContainer, livePreview) {
|
|||||||
if(typeof res != 'object') return
|
if(typeof res != 'object') return
|
||||||
res.output.reverse()
|
res.output.reverse()
|
||||||
res.output.forEach((result, index) => {
|
res.output.forEach((result, index) => {
|
||||||
const imageData = result?.data || result?.path + '?t=' + new Date().getTime(),
|
const imageData = result?.data || result?.path + '?t=' + Date.now(),
|
||||||
imageSeed = result?.seed,
|
imageSeed = result?.seed,
|
||||||
imagePrompt = reqBody.prompt,
|
imagePrompt = reqBody.prompt,
|
||||||
imageInferenceSteps = reqBody.num_inference_steps,
|
imageInferenceSteps = reqBody.num_inference_steps,
|
||||||
@ -440,8 +486,8 @@ function getSaveImageHandler(imageItemElem, outputFormat) {
|
|||||||
}
|
}
|
||||||
function getStartNewTaskHandler(reqBody, imageItemElem, mode) {
|
function getStartNewTaskHandler(reqBody, imageItemElem, mode) {
|
||||||
return function() {
|
return function() {
|
||||||
if (serverStatus !== 'online') {
|
if (!isServerAvailable()) {
|
||||||
alert('The server is still starting up..')
|
alert('The server is not available.')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
const imageElem = imageItemElem.querySelector('img')
|
const imageElem = imageItemElem.querySelector('img')
|
||||||
@ -507,37 +553,76 @@ async function doMakeImage(task) {
|
|||||||
const progressBar = task['progressBar']
|
const progressBar = task['progressBar']
|
||||||
|
|
||||||
let res = undefined
|
let res = undefined
|
||||||
let stepUpdate = undefined
|
|
||||||
try {
|
try {
|
||||||
res = await fetch('/image', {
|
const lastTask = serverState.task
|
||||||
|
let renderRequest = undefined
|
||||||
|
do {
|
||||||
|
res = await fetch('/render', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
},
|
},
|
||||||
body: JSON.stringify(reqBody)
|
body: JSON.stringify(reqBody)
|
||||||
})
|
})
|
||||||
|
renderRequest = await res.json()
|
||||||
|
// status_code 503, already a task running.
|
||||||
|
} while (renderRequest.status_code === 503 && await asyncDelay(30 * 1000))
|
||||||
|
if (typeof renderRequest?.stream !== 'string') {
|
||||||
|
console.log('Endpoint response: ', renderRequest)
|
||||||
|
throw new Error('Endpoint response does not contains a response stream url.')
|
||||||
|
}
|
||||||
|
task['taskStatusLabel'].innerText = "Waiting"
|
||||||
|
task['taskStatusLabel'].classList.add('waitingTaskLabel')
|
||||||
|
task['taskStatusLabel'].classList.remove('activeTaskLabel')
|
||||||
|
|
||||||
|
do { // Wait for server status to update.
|
||||||
|
await asyncDelay(250)
|
||||||
|
if (!isServerAvailable()) {
|
||||||
|
throw new Error('Connexion with server lost.')
|
||||||
|
}
|
||||||
|
} while (serverState.time > (Date.now() - (10 * 1000)) && serverState.task !== renderRequest.task)
|
||||||
|
if (serverState.session !== 'pending' && serverState.session !== 'running' && serverState.session !== 'buffer') {
|
||||||
|
if (serverState.session === 'stopped') {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error('Unexpected server task state: ' + serverState.session || 'Undefined')
|
||||||
|
}
|
||||||
|
while (serverState.task === renderRequest.task && serverState.session === 'pending') {
|
||||||
|
// Wait for task to start on server.
|
||||||
|
await asyncDelay(1500)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Task started!
|
||||||
|
res = await fetch(renderRequest.stream, {
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
task['taskStatusLabel'].innerText = "Processing"
|
||||||
|
task['taskStatusLabel'].classList.add('activeTaskLabel')
|
||||||
|
task['taskStatusLabel'].classList.remove('waitingTaskLabel')
|
||||||
|
|
||||||
|
let stepUpdate = undefined
|
||||||
let reader = res.body.getReader()
|
let reader = res.body.getReader()
|
||||||
let textDecoder = new TextDecoder()
|
let textDecoder = new TextDecoder()
|
||||||
let finalJSON = ''
|
let finalJSON = ''
|
||||||
let prevTime = -1
|
let prevTime = -1
|
||||||
let readComplete = false
|
let readComplete = false
|
||||||
while (true) {
|
while (!readComplete || finalJSON.length > 0) {
|
||||||
let t = new Date().getTime()
|
let t = Date.now()
|
||||||
|
|
||||||
let jsonStr = ''
|
let jsonStr = ''
|
||||||
if (!readComplete) {
|
if (!readComplete) {
|
||||||
const {value, done} = await reader.read()
|
const {value, done} = await reader.read()
|
||||||
if (done) {
|
if (done) {
|
||||||
readComplete = true
|
readComplete = true
|
||||||
}
|
}
|
||||||
if (done && finalJSON.length <= 0 && !value) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if (value) {
|
if (value) {
|
||||||
jsonStr = textDecoder.decode(value)
|
jsonStr = textDecoder.decode(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
stepUpdate = undefined
|
||||||
try {
|
try {
|
||||||
// hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot.
|
// hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot.
|
||||||
// this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste...
|
// this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste...
|
||||||
@ -571,9 +656,6 @@ async function doMakeImage(task) {
|
|||||||
throw e
|
throw e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (readComplete && finalJSON.length <= 0) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if (typeof stepUpdate === 'object' && 'step' in stepUpdate) {
|
if (typeof stepUpdate === 'object' && 'step' in stepUpdate) {
|
||||||
let batchSize = stepUpdate.total_steps
|
let batchSize = stepUpdate.total_steps
|
||||||
let overallStepCount = stepUpdate.step + task.batchesDone * batchSize
|
let overallStepCount = stepUpdate.step + task.batchesDone * batchSize
|
||||||
@ -598,6 +680,23 @@ async function doMakeImage(task) {
|
|||||||
showImages(reqBody, stepUpdate, outputContainer, true)
|
showImages(reqBody, stepUpdate, outputContainer, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (stepUpdate?.status) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if (readComplete && finalJSON.length <= 0) {
|
||||||
|
if (res.status === 200) {
|
||||||
|
await asyncDelay(5000)
|
||||||
|
res = await fetch(renderRequest.stream, {
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
})
|
||||||
|
reader = res.body.getReader()
|
||||||
|
readComplete = false
|
||||||
|
} else {
|
||||||
|
console.log('Stream stopped: ', res)
|
||||||
|
}
|
||||||
|
}
|
||||||
prevTime = t
|
prevTime = t
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -620,21 +719,22 @@ async function doMakeImage(task) {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if (typeof stepUpdate !== 'object' || !res || res.status != 200) {
|
if (typeof stepUpdate !== 'object' || !res || res.status != 200) {
|
||||||
if (serverStatus !== 'online') {
|
if (!isServerAvailable()) {
|
||||||
logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg)
|
logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg)
|
||||||
} else if (typeof res === 'object') {
|
} else if (typeof res === 'object') {
|
||||||
let msg = 'Stable Diffusion had an error reading the response: '
|
let msg = 'Stable Diffusion had an error reading the response: '
|
||||||
try { // 'Response': body stream already read
|
try { // 'Response': body stream already read
|
||||||
msg += 'Read: ' + await res.text()
|
msg += 'Read: ' + await res.text()
|
||||||
} catch(e) {
|
} catch(e) {
|
||||||
msg += 'No error response. '
|
msg += 'Unexpected end of stream. '
|
||||||
}
|
}
|
||||||
if (finalJSON) {
|
if (finalJSON) {
|
||||||
msg += 'Buffered data: ' + finalJSON
|
msg += 'Buffered data: ' + finalJSON
|
||||||
}
|
}
|
||||||
logError(msg, res, outputMsg)
|
logError(msg, res, outputMsg)
|
||||||
} else {
|
} else {
|
||||||
msg = `Unexpected Read Error:<br/><pre>Response:${res}<br/>StepUpdate:${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}</pre>`
|
let msg = `Unexpected Read Error:<br/><pre>Response: ${res}<br/>StepUpdate: ${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}</pre>`
|
||||||
|
logError(msg, res, outputMsg)
|
||||||
}
|
}
|
||||||
progressBar.style.display = 'none'
|
progressBar.style.display = 'none'
|
||||||
return false
|
return false
|
||||||
@ -682,14 +782,14 @@ async function checkTasks() {
|
|||||||
let task = taskQueue.pop()
|
let task = taskQueue.pop()
|
||||||
currentTask = task
|
currentTask = task
|
||||||
|
|
||||||
let time = new Date().getTime()
|
let time = Date.now()
|
||||||
|
|
||||||
let successCount = 0
|
let successCount = 0
|
||||||
|
|
||||||
task.isProcessing = true
|
task.isProcessing = true
|
||||||
task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop'
|
task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop'
|
||||||
task['taskStatusLabel'].innerText = "Processing"
|
task['taskStatusLabel'].innerText = "Starting"
|
||||||
task['taskStatusLabel'].className += " activeTaskLabel"
|
task['taskStatusLabel'].classList.add('waitingTaskLabel')
|
||||||
|
|
||||||
const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1))
|
const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1))
|
||||||
const startSeed = task.reqBody.seed || task.seed
|
const startSeed = task.reqBody.seed || task.seed
|
||||||
@ -724,7 +824,7 @@ async function checkTasks() {
|
|||||||
task['stopTask'].innerHTML = '<i class="fa-solid fa-trash-can"></i> Remove'
|
task['stopTask'].innerHTML = '<i class="fa-solid fa-trash-can"></i> Remove'
|
||||||
task['taskStatusLabel'].style.display = 'none'
|
task['taskStatusLabel'].style.display = 'none'
|
||||||
|
|
||||||
time = new Date().getTime() - time
|
time = Date.now() - time
|
||||||
time /= 1000
|
time /= 1000
|
||||||
|
|
||||||
if (successCount === task.batchCount) {
|
if (successCount === task.batchCount) {
|
||||||
@ -814,8 +914,8 @@ function getCurrentUserRequest() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function makeImage() {
|
function makeImage() {
|
||||||
if (serverStatus !== 'online') {
|
if (!isServerAvailable()) {
|
||||||
alert('The server is still starting up..')
|
alert('The server is not available.')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
const taskTemplate = getCurrentUserRequest()
|
const taskTemplate = getCurrentUserRequest()
|
||||||
@ -868,7 +968,7 @@ function createTask(task) {
|
|||||||
if (task['isProcessing']) {
|
if (task['isProcessing']) {
|
||||||
task.isProcessing = false
|
task.isProcessing = false
|
||||||
try {
|
try {
|
||||||
let res = await fetch('/image/stop')
|
let res = await fetch('/image/stop?session_id=' + sessionId)
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.log(e)
|
console.log(e)
|
||||||
}
|
}
|
||||||
@ -1008,7 +1108,7 @@ async function stopAllTasks() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
let res = await fetch('/image/stop')
|
let res = await fetch('/image/stop?session_id=' + sessionId)
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.log(e)
|
console.log(e)
|
||||||
}
|
}
|
||||||
@ -1135,9 +1235,9 @@ promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
|
|||||||
updatePromptStrength()
|
updatePromptStrength()
|
||||||
|
|
||||||
useBetaChannelField.addEventListener('click', async function(e) {
|
useBetaChannelField.addEventListener('click', async function(e) {
|
||||||
if (serverStatus !== 'online') {
|
if (!isServerAvailable()) {
|
||||||
// logError('The server is still starting up..')
|
// logError('The server is still starting up..')
|
||||||
alert('The server is still starting up..')
|
alert('The server is not available.')
|
||||||
e.preventDefault()
|
e.preventDefault()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -1164,7 +1264,7 @@ useBetaChannelField.addEventListener('click', async function(e) {
|
|||||||
|
|
||||||
async function getAppConfig() {
|
async function getAppConfig() {
|
||||||
try {
|
try {
|
||||||
let res = await fetch('/app_config')
|
let res = await fetch('/get/app_config')
|
||||||
const config = await res.json()
|
const config = await res.json()
|
||||||
|
|
||||||
if (config.update_branch === 'beta') {
|
if (config.update_branch === 'beta') {
|
||||||
@ -1180,7 +1280,7 @@ async function getAppConfig() {
|
|||||||
|
|
||||||
async function getModels() {
|
async function getModels() {
|
||||||
try {
|
try {
|
||||||
let res = await fetch('/models')
|
let res = await fetch('/get/models')
|
||||||
const models = await res.json()
|
const models = await res.json()
|
||||||
|
|
||||||
let activeModel = models['active']
|
let activeModel = models['active']
|
||||||
@ -1451,10 +1551,10 @@ async function getDiskPath() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
let res = await fetch('/output_dir')
|
let res = await fetch('/get/output_dir')
|
||||||
if (res.status === 200) {
|
if (res.status === 200) {
|
||||||
res = await res.json()
|
res = await res.json()
|
||||||
res = res[0]
|
res = res.output_dir
|
||||||
|
|
||||||
document.querySelector('#diskPath').value = res
|
document.querySelector('#diskPath').value = res
|
||||||
}
|
}
|
||||||
@ -1562,14 +1662,15 @@ function resizeModifierCards(val) {
|
|||||||
const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix))
|
const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix))
|
||||||
card.className = classes.join(' ').trim()
|
card.className = classes.join(' ').trim()
|
||||||
|
|
||||||
if(val != 0)
|
if(val != 0) {
|
||||||
card.classList.add(cardSize(val))
|
card.classList.add(cardSize(val))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async function loadModifiers() {
|
async function loadModifiers() {
|
||||||
try {
|
try {
|
||||||
let res = await fetch('/modifiers.json?v=2')
|
let res = await fetch('/get/modifiers')
|
||||||
if (res.status === 200) {
|
if (res.status === 200) {
|
||||||
res = await res.json()
|
res = await res.json()
|
||||||
|
|
||||||
|
@ -197,6 +197,35 @@ def load_model_real_esrgan(real_esrgan_to_use):
|
|||||||
|
|
||||||
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
|
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
|
||||||
|
|
||||||
|
def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
|
||||||
|
if disk_path is None: return None
|
||||||
|
if session_id is None: return None
|
||||||
|
if ext is None: raise Exception('Missing ext')
|
||||||
|
|
||||||
|
session_out_path = os.path.join(disk_path, session_id)
|
||||||
|
os.makedirs(session_out_path, exist_ok=True)
|
||||||
|
|
||||||
|
prompt_flattened = filename_regex.sub('_', prompt)[:50]
|
||||||
|
img_id = str(uuid.uuid4())[-8:]
|
||||||
|
|
||||||
|
if suffix is not None:
|
||||||
|
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
|
||||||
|
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")
|
||||||
|
|
||||||
|
def apply_filters(filter_name, image_data):
|
||||||
|
print(f'Applying filter {filter_name}...')
|
||||||
|
gc()
|
||||||
|
|
||||||
|
if filter_name == 'gfpgan':
|
||||||
|
_, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||||
|
image_data = output[:,:,::-1]
|
||||||
|
|
||||||
|
if filter_name == 'real_esrgan':
|
||||||
|
output, _ = model_real_esrgan.enhance(image_data[:,:,::-1])
|
||||||
|
image_data = output[:,:,::-1]
|
||||||
|
|
||||||
|
return image_data
|
||||||
|
|
||||||
def mk_img(req: Request):
|
def mk_img(req: Request):
|
||||||
try:
|
try:
|
||||||
yield from do_mk_img(req)
|
yield from do_mk_img(req)
|
||||||
@ -283,23 +312,11 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
opt_prompt = req.prompt
|
opt_prompt = req.prompt
|
||||||
opt_seed = req.seed
|
opt_seed = req.seed
|
||||||
opt_n_samples = req.num_outputs
|
|
||||||
opt_n_iter = 1
|
opt_n_iter = 1
|
||||||
opt_scale = req.guidance_scale
|
|
||||||
opt_C = 4
|
opt_C = 4
|
||||||
opt_H = req.height
|
|
||||||
opt_W = req.width
|
|
||||||
opt_f = 8
|
opt_f = 8
|
||||||
opt_ddim_steps = req.num_inference_steps
|
|
||||||
opt_ddim_eta = 0.0
|
opt_ddim_eta = 0.0
|
||||||
opt_strength = req.prompt_strength
|
|
||||||
opt_save_to_disk_path = req.save_to_disk_path
|
|
||||||
opt_init_img = req.init_image
|
opt_init_img = req.init_image
|
||||||
opt_use_face_correction = req.use_face_correction
|
|
||||||
opt_use_upscale = req.use_upscale
|
|
||||||
opt_show_only_filtered = req.show_only_filtered_image
|
|
||||||
opt_format = req.output_format
|
|
||||||
opt_sampler_name = req.sampler
|
|
||||||
|
|
||||||
print(req.to_string(), '\n device', device)
|
print(req.to_string(), '\n device', device)
|
||||||
|
|
||||||
@ -307,7 +324,7 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
seed_everything(opt_seed)
|
seed_everything(opt_seed)
|
||||||
|
|
||||||
batch_size = opt_n_samples
|
batch_size = req.num_outputs
|
||||||
prompt = opt_prompt
|
prompt = opt_prompt
|
||||||
assert prompt is not None
|
assert prompt is not None
|
||||||
data = [batch_size * [prompt]]
|
data = [batch_size * [prompt]]
|
||||||
@ -327,7 +344,7 @@ def do_mk_img(req: Request):
|
|||||||
else:
|
else:
|
||||||
handler = _img2img
|
handler = _img2img
|
||||||
|
|
||||||
init_image = load_img(req.init_image, opt_W, opt_H)
|
init_image = load_img(req.init_image, req.width, req.height)
|
||||||
init_image = init_image.to(device)
|
init_image = init_image.to(device)
|
||||||
|
|
||||||
if device != "cpu" and precision == "autocast":
|
if device != "cpu" and precision == "autocast":
|
||||||
@ -339,7 +356,7 @@ def do_mk_img(req: Request):
|
|||||||
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
||||||
|
|
||||||
if req.mask is not None:
|
if req.mask is not None:
|
||||||
mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
|
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device)
|
||||||
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
||||||
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
||||||
|
|
||||||
@ -348,12 +365,12 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
move_fs_to_cpu()
|
move_fs_to_cpu()
|
||||||
|
|
||||||
assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
t_enc = int(opt_strength * opt_ddim_steps)
|
t_enc = int(req.prompt_strength * req.num_inference_steps)
|
||||||
print(f"target t_enc is {t_enc} steps")
|
print(f"target t_enc is {t_enc} steps")
|
||||||
|
|
||||||
if opt_save_to_disk_path is not None:
|
if req.save_to_disk_path is not None:
|
||||||
session_out_path = os.path.join(opt_save_to_disk_path, req.session_id)
|
session_out_path = os.path.join(req.save_to_disk_path, req.session_id)
|
||||||
os.makedirs(session_out_path, exist_ok=True)
|
os.makedirs(session_out_path, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
session_out_path = None
|
session_out_path = None
|
||||||
@ -366,7 +383,7 @@ def do_mk_img(req: Request):
|
|||||||
with precision_scope("cuda"):
|
with precision_scope("cuda"):
|
||||||
modelCS.to(device)
|
modelCS.to(device)
|
||||||
uc = None
|
uc = None
|
||||||
if opt_scale != 1.0:
|
if req.guidance_scale != 1.0:
|
||||||
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||||
if isinstance(prompts, tuple):
|
if isinstance(prompts, tuple):
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
@ -393,7 +410,7 @@ def do_mk_img(req: Request):
|
|||||||
partial_x_samples = x_samples
|
partial_x_samples = x_samples
|
||||||
|
|
||||||
if req.stream_progress_updates:
|
if req.stream_progress_updates:
|
||||||
n_steps = opt_ddim_steps if req.init_image is None else t_enc
|
n_steps = req.num_inference_steps if req.init_image is None else t_enc
|
||||||
progress = {"step": i, "total_steps": n_steps}
|
progress = {"step": i, "total_steps": n_steps}
|
||||||
|
|
||||||
if req.stream_image_progress and i % 5 == 0:
|
if req.stream_image_progress and i % 5 == 0:
|
||||||
@ -425,9 +442,9 @@ def do_mk_img(req: Request):
|
|||||||
# run the handler
|
# run the handler
|
||||||
try:
|
try:
|
||||||
if handler == _txt2img:
|
if handler == _txt2img:
|
||||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name)
|
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
|
||||||
else:
|
else:
|
||||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask)
|
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)
|
||||||
|
|
||||||
yield from x_samples
|
yield from x_samples
|
||||||
|
|
||||||
@ -447,68 +464,48 @@ def do_mk_img(req: Request):
|
|||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
img = Image.fromarray(x_sample)
|
img = Image.fromarray(x_sample)
|
||||||
|
|
||||||
has_filters = (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
|
has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
|
||||||
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN'))
|
(req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
|
||||||
|
|
||||||
return_orig_img = not has_filters or not opt_show_only_filtered
|
return_orig_img = not has_filters or not req.show_only_filtered_image
|
||||||
|
|
||||||
if stop_processing:
|
if stop_processing:
|
||||||
return_orig_img = True
|
return_orig_img = True
|
||||||
|
|
||||||
if opt_save_to_disk_path is not None:
|
if req.save_to_disk_path is not None:
|
||||||
prompt_flattened = filename_regex.sub('_', prompts[0])
|
|
||||||
prompt_flattened = prompt_flattened[:50]
|
|
||||||
|
|
||||||
img_id = str(uuid.uuid4())[-8:]
|
|
||||||
|
|
||||||
file_path = f"{prompt_flattened}_{img_id}"
|
|
||||||
img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}")
|
|
||||||
meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")
|
|
||||||
|
|
||||||
if return_orig_img:
|
if return_orig_img:
|
||||||
|
img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], req.output_format)
|
||||||
save_image(img, img_out_path)
|
save_image(img, img_out_path)
|
||||||
|
meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], 'txt')
|
||||||
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt, ckpt_file)
|
save_metadata(meta_out_path, req, prompts[0], opt_seed)
|
||||||
|
|
||||||
if return_orig_img:
|
if return_orig_img:
|
||||||
img_data = img_to_base64_str(img, opt_format)
|
img_data = img_to_base64_str(img, req.output_format)
|
||||||
res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
|
res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
|
||||||
res.images.append(res_image_orig)
|
res.images.append(res_image_orig)
|
||||||
|
|
||||||
if opt_save_to_disk_path is not None:
|
if req.save_to_disk_path is not None:
|
||||||
res_image_orig.path_abs = img_out_path
|
res_image_orig.path_abs = img_out_path
|
||||||
|
|
||||||
del img
|
del img
|
||||||
|
|
||||||
if has_filters and not stop_processing:
|
if has_filters and not stop_processing:
|
||||||
print('Applying filters..')
|
|
||||||
|
|
||||||
gc()
|
|
||||||
filters_applied = []
|
filters_applied = []
|
||||||
|
if req.use_face_correction:
|
||||||
if opt_use_face_correction:
|
x_sample = apply_filters('gfpgan', x_sample)
|
||||||
_, _, output = model_gfpgan.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
filters_applied.append(req.use_face_correction)
|
||||||
x_sample = output[:,:,::-1]
|
if req.use_upscale:
|
||||||
filters_applied.append(opt_use_face_correction)
|
x_sample = apply_filters('real_esrgan', x_sample)
|
||||||
|
filters_applied.append(req.use_upscale)
|
||||||
if opt_use_upscale:
|
if (len(filters_applied) > 0):
|
||||||
output, _ = model_real_esrgan.enhance(x_sample[:,:,::-1])
|
|
||||||
x_sample = output[:,:,::-1]
|
|
||||||
filters_applied.append(opt_use_upscale)
|
|
||||||
|
|
||||||
filtered_image = Image.fromarray(x_sample)
|
filtered_image = Image.fromarray(x_sample)
|
||||||
|
filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
|
||||||
filtered_img_data = img_to_base64_str(filtered_image, opt_format)
|
response_image = ResponseImage(data=filtered_img_data, seed=req.seed)
|
||||||
res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed)
|
res.images.append(response_image)
|
||||||
res.images.append(res_image_filtered)
|
if req.save_to_disk_path is not None:
|
||||||
|
filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], req.output_format, "_".join(filters_applied))
|
||||||
filters_applied = "_".join(filters_applied)
|
|
||||||
|
|
||||||
if opt_save_to_disk_path is not None:
|
|
||||||
filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}")
|
|
||||||
save_image(filtered_image, filtered_img_out_path)
|
save_image(filtered_image, filtered_img_out_path)
|
||||||
res_image_filtered.path_abs = filtered_img_out_path
|
response_image.path_abs = filtered_img_out_path
|
||||||
|
|
||||||
del filtered_image
|
del filtered_image
|
||||||
|
|
||||||
seeds += str(opt_seed) + ","
|
seeds += str(opt_seed) + ","
|
||||||
@ -529,9 +526,20 @@ def save_image(img, img_out_path):
|
|||||||
except:
|
except:
|
||||||
print('could not save the file', traceback.format_exc())
|
print('could not save the file', traceback.format_exc())
|
||||||
|
|
||||||
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt, ckpt_file):
|
def save_metadata(meta_out_path, req, prompt, opt_seed):
|
||||||
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}\nStable Diffusion Model: {ckpt_file + '.ckpt'}"
|
metadata = f"""{prompt}
|
||||||
|
Width: {req.width}
|
||||||
|
Height: {req.height}
|
||||||
|
Seed: {opt_seed}
|
||||||
|
Steps: {req.num_inference_steps}
|
||||||
|
Guidance Scale: {req.guidance_scale}
|
||||||
|
Prompt Strength: {req.prompt_strength}
|
||||||
|
Use Face Correction: {req.use_face_correction}
|
||||||
|
Use Upscaling: {req.use_upscale}
|
||||||
|
Sampler: {req.sampler}
|
||||||
|
Negative Prompt: {req.negative_prompt}
|
||||||
|
Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
with open(meta_out_path, 'w') as f:
|
with open(meta_out_path, 'w') as f:
|
||||||
f.write(metadata)
|
f.write(metadata)
|
||||||
|
298
ui/sd_internal/task_manager.py
Normal file
298
ui/sd_internal/task_manager.py
Normal file
@ -0,0 +1,298 @@
|
|||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||||
|
|
||||||
|
import queue, threading, time
|
||||||
|
from typing import Any, Generator, Hashable, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sd_internal import Request, Response
|
||||||
|
|
||||||
|
class SymbolClass(type): # Print nicely formatted Symbol names.
|
||||||
|
def __repr__(self): return self.__qualname__
|
||||||
|
def __str__(self): return self.__name__
|
||||||
|
class Symbol(metaclass=SymbolClass): pass
|
||||||
|
|
||||||
|
class ServerStates:
|
||||||
|
class Init(Symbol): pass
|
||||||
|
class LoadingModel(Symbol): pass
|
||||||
|
class Online(Symbol): pass
|
||||||
|
class Rendering(Symbol): pass
|
||||||
|
class Unavailable(Symbol): pass
|
||||||
|
|
||||||
|
class RenderTask(): # Task with output queue and completion lock.
|
||||||
|
def __init__(self, req: Request):
|
||||||
|
self.request: Request = req # Initial Request
|
||||||
|
self.response: Any = None # Copy of the last reponse
|
||||||
|
self.temp_images:[] = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
|
||||||
|
self.error: Exception = None
|
||||||
|
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
|
||||||
|
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
|
||||||
|
async def read_buffer_generator(self):
|
||||||
|
try:
|
||||||
|
while not self.buffer_queue.empty():
|
||||||
|
res = self.buffer_queue.get(block=False)
|
||||||
|
self.buffer_queue.task_done()
|
||||||
|
yield res
|
||||||
|
except queue.Empty as e: yield
|
||||||
|
|
||||||
|
# defaults from https://huggingface.co/blog/stable_diffusion
|
||||||
|
class ImageRequest(BaseModel):
|
||||||
|
session_id: str = "session"
|
||||||
|
prompt: str = ""
|
||||||
|
negative_prompt: str = ""
|
||||||
|
init_image: str = None # base64
|
||||||
|
mask: str = None # base64
|
||||||
|
num_outputs: int = 1
|
||||||
|
num_inference_steps: int = 50
|
||||||
|
guidance_scale: float = 7.5
|
||||||
|
width: int = 512
|
||||||
|
height: int = 512
|
||||||
|
seed: int = 42
|
||||||
|
prompt_strength: float = 0.8
|
||||||
|
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
|
||||||
|
# allow_nsfw: bool = False
|
||||||
|
save_to_disk_path: str = None
|
||||||
|
turbo: bool = True
|
||||||
|
use_cpu: bool = False
|
||||||
|
use_full_precision: bool = False
|
||||||
|
use_face_correction: str = None # or "GFPGANv1.3"
|
||||||
|
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||||
|
use_stable_diffusion_model: str = "sd-v1-4"
|
||||||
|
show_only_filtered_image: bool = False
|
||||||
|
output_format: str = "jpeg" # or "png"
|
||||||
|
|
||||||
|
stream_progress_updates: bool = False
|
||||||
|
stream_image_progress: bool = False
|
||||||
|
|
||||||
|
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||||
|
class TaskCache():
|
||||||
|
def __init__(self):
|
||||||
|
self._base = dict()
|
||||||
|
self._lock: threading.Lock = threading.RLock()
|
||||||
|
def _get_ttl_time(self, ttl: int) -> int:
|
||||||
|
return int(time.time()) + ttl
|
||||||
|
def _is_expired(self, timestamp: int) -> bool:
|
||||||
|
return int(time.time()) >= timestamp
|
||||||
|
def clean(self) -> None:
|
||||||
|
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clean failed to acquire lock within timeout.')
|
||||||
|
try:
|
||||||
|
# Create a list of expired keys to delete
|
||||||
|
to_delete = []
|
||||||
|
for key in self._base:
|
||||||
|
ttl, _ = self._base[key]
|
||||||
|
if self._is_expired(ttl):
|
||||||
|
to_delete.append(key)
|
||||||
|
# Remove Items
|
||||||
|
for key in to_delete:
|
||||||
|
del self._base[key]
|
||||||
|
print(f'Session {key} expired. Data removed.')
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
def clear(self) -> None:
|
||||||
|
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clear failed to acquire lock within timeout.')
|
||||||
|
try: self._base.clear()
|
||||||
|
finally: self._lock.release()
|
||||||
|
def delete(self, key: Hashable) -> bool:
|
||||||
|
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.delete failed to acquire lock within timeout.')
|
||||||
|
try:
|
||||||
|
if key not in self._base:
|
||||||
|
return False
|
||||||
|
del self._base[key]
|
||||||
|
return True
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
def keep(self, key: Hashable, ttl: int) -> bool:
|
||||||
|
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.keep failed to acquire lock within timeout.')
|
||||||
|
try:
|
||||||
|
if key in self._base:
|
||||||
|
_, value = self._base.get(key)
|
||||||
|
self._base[key] = (self._get_ttl_time(ttl), value)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
||||||
|
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.put failed to acquire lock within timeout.')
|
||||||
|
try:
|
||||||
|
self._base[key] = (
|
||||||
|
self._get_ttl_time(ttl), value
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(str(e))
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
def tryGet(self, key: Hashable) -> Any:
|
||||||
|
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.tryGet failed to acquire lock within timeout.')
|
||||||
|
try:
|
||||||
|
ttl, value = self._base.get(key, (None, None))
|
||||||
|
if ttl is not None and self._is_expired(ttl):
|
||||||
|
print(f'Session {key} expired. Discarding data.')
|
||||||
|
self.delete(key)
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
current_state = ServerStates.Init
|
||||||
|
current_state_error:Exception = None
|
||||||
|
current_model_path = None
|
||||||
|
tasks_queue = queue.Queue()
|
||||||
|
task_cache = TaskCache()
|
||||||
|
default_model_to_load = None
|
||||||
|
|
||||||
|
def preload_model(file_path=None):
|
||||||
|
global current_state, current_state_error, current_model_path
|
||||||
|
if file_path == None:
|
||||||
|
file_path = default_model_to_load
|
||||||
|
if file_path == current_model_path:
|
||||||
|
return
|
||||||
|
current_state = ServerStates.LoadingModel
|
||||||
|
try:
|
||||||
|
from . import runtime
|
||||||
|
runtime.load_model_ckpt(ckpt_to_use=file_path)
|
||||||
|
current_model_path = file_path
|
||||||
|
current_state_error = None
|
||||||
|
current_state = ServerStates.Online
|
||||||
|
except Exception as e:
|
||||||
|
current_model_path = None
|
||||||
|
current_state_error = e
|
||||||
|
current_state = ServerStates.Unavailable
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
def thread_render():
|
||||||
|
global current_state, current_state_error, current_model_path
|
||||||
|
from . import runtime
|
||||||
|
current_state = ServerStates.Online
|
||||||
|
preload_model()
|
||||||
|
while True:
|
||||||
|
task_cache.clean()
|
||||||
|
if isinstance(current_state_error, SystemExit):
|
||||||
|
current_state = ServerStates.Unavailable
|
||||||
|
return
|
||||||
|
task = None
|
||||||
|
try:
|
||||||
|
task = tasks_queue.get(timeout=1)
|
||||||
|
except queue.Empty as e:
|
||||||
|
if isinstance(current_state_error, SystemExit):
|
||||||
|
current_state = ServerStates.Unavailable
|
||||||
|
return
|
||||||
|
else: continue
|
||||||
|
#if current_model_path != task.request.use_stable_diffusion_model:
|
||||||
|
# preload_model(task.request.use_stable_diffusion_model)
|
||||||
|
if current_state_error:
|
||||||
|
task.error = current_state_error
|
||||||
|
continue
|
||||||
|
print(f'Session {task.request.session_id} starting task {id(task)}')
|
||||||
|
try:
|
||||||
|
task.lock.acquire(blocking=False)
|
||||||
|
res = runtime.mk_img(task.request)
|
||||||
|
if current_model_path == task.request.use_stable_diffusion_model:
|
||||||
|
current_state = ServerStates.Rendering
|
||||||
|
else:
|
||||||
|
current_state = ServerStates.LoadingModel
|
||||||
|
except Exception as e:
|
||||||
|
task.error = e
|
||||||
|
task.lock.release()
|
||||||
|
tasks_queue.task_done()
|
||||||
|
print(traceback.format_exc())
|
||||||
|
continue
|
||||||
|
dataQueue = None
|
||||||
|
if task.request.stream_progress_updates:
|
||||||
|
dataQueue = task.buffer_queue
|
||||||
|
for result in res:
|
||||||
|
if current_state == ServerStates.LoadingModel:
|
||||||
|
current_state = ServerStates.Rendering
|
||||||
|
current_model_path = task.request.use_stable_diffusion_model
|
||||||
|
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
||||||
|
runtime.stop_processing = True
|
||||||
|
if isinstance(current_state_error, StopAsyncIteration):
|
||||||
|
task.error = current_state_error
|
||||||
|
current_state_error = None
|
||||||
|
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
||||||
|
if dataQueue:
|
||||||
|
dataQueue.put(result)
|
||||||
|
if isinstance(result, str):
|
||||||
|
result = json.loads(result)
|
||||||
|
task.response = result
|
||||||
|
if 'output' in result:
|
||||||
|
for out_obj in result['output']:
|
||||||
|
if 'path' in out_obj:
|
||||||
|
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
|
||||||
|
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]]
|
||||||
|
elif 'data' in out_obj:
|
||||||
|
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
|
||||||
|
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||||
|
# Task completed
|
||||||
|
task.lock.release()
|
||||||
|
tasks_queue.task_done()
|
||||||
|
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||||
|
if isinstance(task.error, StopAsyncIteration):
|
||||||
|
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
||||||
|
elif task.error is not None:
|
||||||
|
print(f'Session {task.request.session_id} task {id(task)} failed!')
|
||||||
|
else:
|
||||||
|
print(f'Session {task.request.session_id} task {id(task)} completed.')
|
||||||
|
current_state = ServerStates.Online
|
||||||
|
|
||||||
|
render_thread = threading.Thread(target=thread_render)
|
||||||
|
|
||||||
|
def start_render_thread():
|
||||||
|
# Start Rendering Thread
|
||||||
|
render_thread.daemon = True
|
||||||
|
render_thread.start()
|
||||||
|
|
||||||
|
def shutdown_event(): # Signal render thread to close on shutdown
|
||||||
|
global current_state_error
|
||||||
|
current_state_error = SystemExit('Application shutting down.')
|
||||||
|
|
||||||
|
def render(req : ImageRequest):
|
||||||
|
if not render_thread.is_alive(): # Render thread is dead
|
||||||
|
raise ChildProcessError('Rendering thread has died.')
|
||||||
|
# Alive, check if task in cache
|
||||||
|
task = task_cache.tryGet(req.session_id)
|
||||||
|
if task and not task.response and not task.error and not task.lock.locked():
|
||||||
|
# Unstarted task pending, deny queueing more than one.
|
||||||
|
raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.')
|
||||||
|
#
|
||||||
|
from . import runtime
|
||||||
|
r = Request()
|
||||||
|
r.session_id = req.session_id
|
||||||
|
r.prompt = req.prompt
|
||||||
|
r.negative_prompt = req.negative_prompt
|
||||||
|
r.init_image = req.init_image
|
||||||
|
r.mask = req.mask
|
||||||
|
r.num_outputs = req.num_outputs
|
||||||
|
r.num_inference_steps = req.num_inference_steps
|
||||||
|
r.guidance_scale = req.guidance_scale
|
||||||
|
r.width = req.width
|
||||||
|
r.height = req.height
|
||||||
|
r.seed = req.seed
|
||||||
|
r.prompt_strength = req.prompt_strength
|
||||||
|
r.sampler = req.sampler
|
||||||
|
# r.allow_nsfw = req.allow_nsfw
|
||||||
|
r.turbo = req.turbo
|
||||||
|
r.use_cpu = req.use_cpu
|
||||||
|
r.use_full_precision = req.use_full_precision
|
||||||
|
r.save_to_disk_path = req.save_to_disk_path
|
||||||
|
r.use_upscale: str = req.use_upscale
|
||||||
|
r.use_face_correction = req.use_face_correction
|
||||||
|
r.show_only_filtered_image = req.show_only_filtered_image
|
||||||
|
r.output_format = req.output_format
|
||||||
|
|
||||||
|
r.stream_progress_updates = True # the underlying implementation only supports streaming
|
||||||
|
r.stream_image_progress = req.stream_image_progress
|
||||||
|
|
||||||
|
if not req.stream_progress_updates:
|
||||||
|
r.stream_image_progress = False
|
||||||
|
|
||||||
|
new_task = RenderTask(r)
|
||||||
|
if task_cache.put(r.session_id, new_task, TASK_TTL):
|
||||||
|
tasks_queue.put(new_task, block=True, timeout=30)
|
||||||
|
return new_task
|
||||||
|
raise RuntimeError('Failed to add task to cache.')
|
280
ui/server.py
280
ui/server.py
@ -14,90 +14,32 @@ CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
|
|||||||
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
|
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
|
||||||
|
|
||||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||||
|
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from starlette.responses import FileResponse, StreamingResponse
|
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import logging
|
import logging
|
||||||
|
import queue, threading, time
|
||||||
|
from typing import Any, Generator, Hashable, Optional, Union
|
||||||
|
|
||||||
from sd_internal import Request, Response
|
from sd_internal import Request, Response, task_manager
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
model_loaded = False
|
|
||||||
model_is_loading = False
|
|
||||||
|
|
||||||
modifiers_cache = None
|
modifiers_cache = None
|
||||||
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
|
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
|
||||||
|
|
||||||
# don't show access log entries for URLs that start with the given prefix
|
# don't show access log entries for URLs that start with the given prefix
|
||||||
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/modifier-thumbnails']
|
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
|
||||||
|
|
||||||
|
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||||
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
|
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
|
||||||
|
|
||||||
# defaults from https://huggingface.co/blog/stable_diffusion
|
|
||||||
class ImageRequest(BaseModel):
|
|
||||||
session_id: str = "session"
|
|
||||||
prompt: str = ""
|
|
||||||
negative_prompt: str = ""
|
|
||||||
init_image: str = None # base64
|
|
||||||
mask: str = None # base64
|
|
||||||
num_outputs: int = 1
|
|
||||||
num_inference_steps: int = 50
|
|
||||||
guidance_scale: float = 7.5
|
|
||||||
width: int = 512
|
|
||||||
height: int = 512
|
|
||||||
seed: int = 42
|
|
||||||
prompt_strength: float = 0.8
|
|
||||||
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
|
|
||||||
# allow_nsfw: bool = False
|
|
||||||
save_to_disk_path: str = None
|
|
||||||
turbo: bool = True
|
|
||||||
use_cpu: bool = False
|
|
||||||
use_full_precision: bool = False
|
|
||||||
use_face_correction: str = None # or "GFPGANv1.3"
|
|
||||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
|
||||||
use_stable_diffusion_model: str = "sd-v1-4"
|
|
||||||
show_only_filtered_image: bool = False
|
|
||||||
output_format: str = "jpeg" # or "png"
|
|
||||||
|
|
||||||
stream_progress_updates: bool = False
|
|
||||||
stream_image_progress: bool = False
|
|
||||||
|
|
||||||
class SetAppConfigRequest(BaseModel):
|
class SetAppConfigRequest(BaseModel):
|
||||||
update_branch: str = "main"
|
update_branch: str = "main"
|
||||||
|
|
||||||
@app.get('/')
|
|
||||||
def read_root():
|
|
||||||
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
|
||||||
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers)
|
|
||||||
|
|
||||||
@app.get('/ping')
|
|
||||||
async def ping():
|
|
||||||
global model_loaded, model_is_loading
|
|
||||||
|
|
||||||
try:
|
|
||||||
if model_loaded:
|
|
||||||
return {'OK'}
|
|
||||||
|
|
||||||
if model_is_loading:
|
|
||||||
return {'ERROR'}
|
|
||||||
|
|
||||||
model_is_loading = True
|
|
||||||
|
|
||||||
from sd_internal import runtime
|
|
||||||
|
|
||||||
runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load())
|
|
||||||
|
|
||||||
model_loaded = True
|
|
||||||
model_is_loading = False
|
|
||||||
|
|
||||||
return {'OK'}
|
|
||||||
except Exception as e:
|
|
||||||
print(traceback.format_exc())
|
|
||||||
return HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
# needs to support the legacy installations
|
# needs to support the legacy installations
|
||||||
def get_initial_model_to_load():
|
def get_initial_model_to_load():
|
||||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||||
@ -114,7 +56,6 @@ def get_initial_model_to_load():
|
|||||||
ckpt_to_use = model_path
|
ckpt_to_use = model_path
|
||||||
else:
|
else:
|
||||||
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
|
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
|
||||||
|
|
||||||
return ckpt_to_use
|
return ckpt_to_use
|
||||||
|
|
||||||
def resolve_model_to_use(model_name):
|
def resolve_model_to_use(model_name):
|
||||||
@ -126,92 +67,110 @@ def resolve_model_to_use(model_name):
|
|||||||
model_path = legacy_model_path
|
model_path = legacy_model_path
|
||||||
else:
|
else:
|
||||||
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||||
|
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
def shutdown_event(): # Signal render thread to close on shutdown
|
||||||
|
task_manager.current_state_error = SystemExit('Application shutting down.')
|
||||||
|
|
||||||
|
@app.get('/')
|
||||||
|
def read_root():
|
||||||
|
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
||||||
|
|
||||||
|
@app.get('/ping') # Get server and optionally session status.
|
||||||
|
def ping(session_id:str=None):
|
||||||
|
if not task_manager.render_thread.is_alive(): # Render thread is dead.
|
||||||
|
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error))
|
||||||
|
return HTTPException(status_code=500, detail='Render thread is dead.')
|
||||||
|
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error))
|
||||||
|
# Alive
|
||||||
|
response = {'status': str(task_manager.current_state)}
|
||||||
|
if session_id:
|
||||||
|
task = task_manager.task_cache.tryGet(session_id)
|
||||||
|
if task:
|
||||||
|
response['task'] = id(task)
|
||||||
|
if task.lock.locked():
|
||||||
|
response['session'] = 'running'
|
||||||
|
elif isinstance(task.error, StopAsyncIteration):
|
||||||
|
response['session'] = 'stopped'
|
||||||
|
elif task.error:
|
||||||
|
response['session'] = 'error'
|
||||||
|
elif not task.buffer_queue.empty():
|
||||||
|
response['session'] = 'buffer'
|
||||||
|
elif task.response:
|
||||||
|
response['session'] = 'completed'
|
||||||
|
else:
|
||||||
|
response['session'] = 'pending'
|
||||||
|
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||||
|
|
||||||
def save_model_to_config(model_name):
|
def save_model_to_config(model_name):
|
||||||
config = getConfig()
|
config = getConfig()
|
||||||
if 'model' not in config:
|
if 'model' not in config:
|
||||||
config['model'] = {}
|
config['model'] = {}
|
||||||
|
|
||||||
config['model']['stable-diffusion'] = model_name
|
config['model']['stable-diffusion'] = model_name
|
||||||
|
|
||||||
setConfig(config)
|
setConfig(config)
|
||||||
|
|
||||||
@app.post('/image')
|
@app.post('/render')
|
||||||
def image(req : ImageRequest):
|
def render(req : task_manager.ImageRequest):
|
||||||
from sd_internal import runtime
|
|
||||||
|
|
||||||
r = Request()
|
|
||||||
r.session_id = req.session_id
|
|
||||||
r.prompt = req.prompt
|
|
||||||
r.negative_prompt = req.negative_prompt
|
|
||||||
r.init_image = req.init_image
|
|
||||||
r.mask = req.mask
|
|
||||||
r.num_outputs = req.num_outputs
|
|
||||||
r.num_inference_steps = req.num_inference_steps
|
|
||||||
r.guidance_scale = req.guidance_scale
|
|
||||||
r.width = req.width
|
|
||||||
r.height = req.height
|
|
||||||
r.seed = req.seed
|
|
||||||
r.prompt_strength = req.prompt_strength
|
|
||||||
r.sampler = req.sampler
|
|
||||||
# r.allow_nsfw = req.allow_nsfw
|
|
||||||
r.turbo = req.turbo
|
|
||||||
r.use_cpu = req.use_cpu
|
|
||||||
r.use_full_precision = req.use_full_precision
|
|
||||||
r.save_to_disk_path = req.save_to_disk_path
|
|
||||||
r.use_upscale: str = req.use_upscale
|
|
||||||
r.use_face_correction = req.use_face_correction
|
|
||||||
r.show_only_filtered_image = req.show_only_filtered_image
|
|
||||||
r.output_format = req.output_format
|
|
||||||
|
|
||||||
r.stream_progress_updates = True # the underlying implementation only supports streaming
|
|
||||||
r.stream_image_progress = req.stream_image_progress
|
|
||||||
|
|
||||||
r.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
|
|
||||||
|
|
||||||
save_model_to_config(req.use_stable_diffusion_model)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not req.stream_progress_updates:
|
save_model_to_config(req.use_stable_diffusion_model)
|
||||||
r.stream_image_progress = False
|
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
|
||||||
|
new_task = task_manager.render(req)
|
||||||
res = runtime.mk_img(r)
|
response = {
|
||||||
|
'status': str(task_manager.current_state),
|
||||||
if req.stream_progress_updates:
|
'queue': task_manager.tasks_queue.qsize(),
|
||||||
return StreamingResponse(res, media_type='application/json')
|
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
|
||||||
else: # compatibility mode: buffer the streaming responses, and return the last one
|
'task': id(new_task)
|
||||||
last_result = None
|
}
|
||||||
|
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||||
for result in res:
|
except ChildProcessError as e: # Render thread is dead
|
||||||
last_result = result
|
return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
|
||||||
|
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
|
||||||
return json.loads(last_result)
|
return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
|
||||||
return HTTPException(status_code=500, detail=str(e))
|
return HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.get('/image/stream/{session_id:str}/{task_id:int}')
|
||||||
|
def stream(session_id:str, task_id:int):
|
||||||
|
#TODO Move to WebSockets ??
|
||||||
|
task = task_manager.task_cache.tryGet(session_id)
|
||||||
|
if not task: return HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
|
||||||
|
if (id(task) != task_id): return HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||||
|
if task.buffer_queue.empty() and not task.lock.locked():
|
||||||
|
if task.response:
|
||||||
|
#print(f'Session {session_id} sending cached response')
|
||||||
|
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
|
||||||
|
return HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
|
||||||
|
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
||||||
|
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
||||||
|
|
||||||
@app.get('/image/stop')
|
@app.get('/image/stop')
|
||||||
def stop():
|
def stop(session_id:str=None):
|
||||||
try:
|
if not session_id:
|
||||||
if model_is_loading:
|
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
|
||||||
return {'ERROR'}
|
return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
|
||||||
|
task_manager.current_state_error = StopAsyncIteration('')
|
||||||
from sd_internal import runtime
|
return {'OK'}
|
||||||
runtime.stop_processing = True
|
task = task_manager.task_cache.tryGet(session_id)
|
||||||
|
if not task: return HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
|
||||||
|
if isinstance(task.error, StopAsyncIteration): return HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
|
||||||
|
task.error = StopAsyncIteration('')
|
||||||
return {'OK'}
|
return {'OK'}
|
||||||
except Exception as e:
|
|
||||||
print(traceback.format_exc())
|
|
||||||
return HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
@app.get('/image/tmp/{session_id}/{img_id}')
|
@app.get('/image/tmp/{session_id}/{img_id:int}')
|
||||||
def get_image(session_id, img_id):
|
def get_image(session_id, img_id):
|
||||||
from sd_internal import runtime
|
task = task_manager.task_cache.tryGet(session_id)
|
||||||
buf = runtime.temp_images[session_id + '/' + img_id]
|
if not task: return HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
|
||||||
buf.seek(0)
|
if not task.temp_images[img_id]: return HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
|
||||||
return StreamingResponse(buf, media_type='image/jpeg')
|
try:
|
||||||
|
img_data = task.temp_images[img_id]
|
||||||
|
if isinstance(img_data, str):
|
||||||
|
return img_data
|
||||||
|
img_data.seek(0)
|
||||||
|
return StreamingResponse(img_data, media_type='image/jpeg')
|
||||||
|
except KeyError as e:
|
||||||
|
return HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.post('/app_config')
|
@app.post('/app_config')
|
||||||
async def setAppConfig(req : SetAppConfigRequest):
|
async def setAppConfig(req : SetAppConfigRequest):
|
||||||
@ -242,42 +201,27 @@ async def setAppConfig(req : SetAppConfigRequest):
|
|||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
return HTTPException(status_code=500, detail=str(e))
|
return HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.get('/app_config')
|
def getConfig(default_val={}):
|
||||||
def getAppConfig():
|
|
||||||
try:
|
try:
|
||||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||||
|
|
||||||
if not os.path.exists(config_json_path):
|
if not os.path.exists(config_json_path):
|
||||||
return HTTPException(status_code=500, detail="No config file")
|
return default_val
|
||||||
|
|
||||||
with open(config_json_path, 'r') as f:
|
with open(config_json_path, 'r') as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(str(e))
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
return HTTPException(status_code=500, detail=str(e))
|
return default_val
|
||||||
|
|
||||||
def getConfig():
|
|
||||||
try:
|
|
||||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
|
||||||
|
|
||||||
if not os.path.exists(config_json_path):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
with open(config_json_path, 'r') as f:
|
|
||||||
return json.load(f)
|
|
||||||
except Exception as e:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def setConfig(config):
|
def setConfig(config):
|
||||||
try:
|
try:
|
||||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||||
|
|
||||||
with open(config_json_path, 'w') as f:
|
with open(config_json_path, 'w') as f:
|
||||||
return json.dump(config, f)
|
return json.dump(config, f)
|
||||||
except:
|
except:
|
||||||
|
print(str(e))
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
@app.get('/models')
|
|
||||||
def getModels():
|
def getModels():
|
||||||
models = {
|
models = {
|
||||||
'active': {
|
'active': {
|
||||||
@ -307,14 +251,21 @@ def getModels():
|
|||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
@app.get('/modifiers.json')
|
@app.get('/get/{key:path}')
|
||||||
def read_modifiers():
|
def read_web_data(key:str=None):
|
||||||
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
if not key: # /get without parameters, stable-diffusion easter egg.
|
||||||
return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=headers)
|
return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
||||||
|
elif key == 'app_config':
|
||||||
@app.get('/output_dir')
|
config = getConfig(default_val=None)
|
||||||
def read_home_dir():
|
if config is None:
|
||||||
return {outpath}
|
return HTTPException(status_code=500, detail="Config file is missing or unreadable")
|
||||||
|
return JSONResponse(config, headers=NOCACHE_HEADERS)
|
||||||
|
elif key == 'models':
|
||||||
|
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
|
||||||
|
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
|
||||||
|
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
|
||||||
|
else:
|
||||||
|
return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
|
||||||
|
|
||||||
# don't log certain requests
|
# don't log certain requests
|
||||||
class LogSuppressFilter(logging.Filter):
|
class LogSuppressFilter(logging.Filter):
|
||||||
@ -323,10 +274,11 @@ class LogSuppressFilter(logging.Filter):
|
|||||||
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
|
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
|
||||||
if path.find(prefix) != -1:
|
if path.find(prefix) != -1:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
|
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
|
||||||
|
|
||||||
|
task_manager.default_model_to_load = get_initial_model_to_load()
|
||||||
|
task_manager.start_render_thread()
|
||||||
|
|
||||||
# start the browser ui
|
# start the browser ui
|
||||||
import webbrowser; webbrowser.open('http://localhost:9000')
|
import webbrowser; webbrowser.open('http://localhost:9000')
|
Loading…
Reference in New Issue
Block a user