Add support for img2img

This commit is contained in:
cmdr2 2022-08-25 21:46:31 +05:30
parent fd7519c444
commit 2c4162b15c
3 changed files with 150 additions and 41 deletions

View File

@ -5,7 +5,7 @@ services:
container_name: sd container_name: sd
ports: ports:
- '5000:5000' - '5000:5000'
image: 'r8.im/stability-ai/stable-diffusion@sha256:06eb78b36068500c616a7f33c15e6fa40404f8e14b5bfad57ebe0c7fe0f6bdf1' image: 'r8.im/andreasjansson/stable-diffusion-wip@sha256:984cf13f8875bf9eec0860e87ce41143f40b6f07f0112505c6833147763353ad'
deploy: deploy:
resources: resources:
reservations: reservations:

View File

@ -6,6 +6,9 @@
font-family: Arial, Helvetica, sans-serif; font-family: Arial, Helvetica, sans-serif;
font-size: 11pt; font-size: 11pt;
} }
label {
font-size: 10pt;
}
#prompt { #prompt {
width: 50vw; width: 50vw;
height: 50pt; height: 50pt;
@ -15,6 +18,23 @@
width: 95%; width: 95%;
} }
} }
#init_image_preview_container {
display: none;
}
#init_image_clear {
position: absolute;
transform: translateX(-50%);
background: black;
color: white;
border: 2pt solid #ccc;
padding: 0;
cursor: pointer;
outline: inherit;
border-radius: 8pt;
width: 16pt;
height: 16pt;
font-size: 10pt;
}
#configHeader { #configHeader {
margin-top: 5px; margin-top: 5px;
margin-bottom: 5px; margin-bottom: 5px;
@ -29,10 +49,21 @@
font-size: small; font-size: small;
} }
#footer { #footer {
margin-top: 5px; border-top: 1px solid #999;
padding-top: 5px; margin-top: 10px;
padding-top: 10px;
font-size: small; font-size: small;
} }
.imgUseBtn {
position: absolute;
transform: translateX(-100%);
margin-top: 5pt;
margin-left: -5pt;
}
.imgItem {
display: inline;
padding-right: 10px;
}
</style> </style>
</html> </html>
<body> <body>
@ -43,6 +74,12 @@
<b>Prompt:</b><br/> <b>Prompt:</b><br/>
<textarea id="prompt">a photograph of an astronaut riding a horse</textarea><br/> <textarea id="prompt">a photograph of an astronaut riding a horse</textarea><br/>
<label for="init_image"><b>Initial Image:</b> (optional) </label> <input id="init_image" name="init_image" type="file" /> </button><br/>
<div id="init_image_preview_container">
<img id="init_image_preview" src="" width="100" height="100" />
<button id="init_image_clear">X</button>
</div>
<div id="configHeader"><b>Advanced settings:</b> [<a id="configToggleBtn" href="#">show</a>]</div> <div id="configHeader"><b>Advanced settings:</b> [<a id="configToggleBtn" href="#">show</a>]</div>
<div id="config"> <div id="config">
<label for="seed">Seed:</label> <input id="seed" name="seed" value="30000"> <input id="random_seed" name="random_seed" type="checkbox" checked> <label for="random_seed">Random Image</label> <br/> <label for="seed">Seed:</label> <input id="seed" name="seed" value="30000"> <input id="random_seed" name="random_seed" type="checkbox" checked> <label for="random_seed">Random Image</label> <br/>
@ -70,6 +107,28 @@
<script> <script>
const HEALTH_PING_INTERVAL = 5 // seconds const HEALTH_PING_INTERVAL = 5 // seconds
let promptField = document.querySelector('#prompt')
let numOutputsField = document.querySelector('#num_outputs')
let numInferenceStepsField = document.querySelector('#num_inference_steps')
let guidanceScaleField = document.querySelector('#guidance_scale')
let guidanceScaleValueLabel = document.querySelector('#guidance_scale_value')
let randomSeedField = document.querySelector("#random_seed")
let seedField = document.querySelector('#seed')
let widthField = document.querySelector('#width')
let heightField = document.querySelector('#height')
let initImageSelector = document.querySelector("#init_image")
let initImagePreview = document.querySelector("#init_image_preview")
let makeImageBtn = document.querySelector('#makeImage')
let imagesContainer = document.querySelector('#images')
let initImagePreviewContainer = document.querySelector('#init_image_preview_container')
let initImageClearBtn = document.querySelector('#init_image_clear')
let showConfigToggle = document.querySelector('#configToggleBtn')
let configBox = document.querySelector('#config')
let outputMsg = document.querySelector('#outputMsg')
let serverStatus = 'offline' let serverStatus = 'offline'
function setStatus(statusType, msg, msgType) { function setStatus(statusType, msg, msgType) {
@ -117,11 +176,9 @@ async function healthCheck() {
async function makeImage() { async function makeImage() {
setStatus('request', 'fetching..') setStatus('request', 'fetching..')
let btn = document.querySelector('#makeImage') makeImageBtn.innerHTML = 'Processing..'
btn.innerHTML = 'Processing..' makeImageBtn.disabled = true
btn.disabled = true;
let outputMsg = document.querySelector('#outputMsg')
outputMsg.innerHTML = 'Fetching..' outputMsg.innerHTML = 'Fetching..'
function logError(msg, res) { function logError(msg, res) {
@ -130,18 +187,22 @@ async function makeImage() {
setStatus('request', 'error', 'error') setStatus('request', 'error', 'error')
} }
let random_seed = document.querySelector("#random_seed") let seed = (randomSeedField.checked ? Math.floor(Math.random() * 10000) : seedField.value)
let seed = (random_seed.checked ? Math.floor(Math.random() * 10000) : document.querySelector('#seed').value)
let reqBody = { let reqBody = {
prompt: document.querySelector('#prompt').value, prompt: promptField.value,
num_outputs: document.querySelector('#num_outputs').value, num_outputs: numOutputsField.value,
num_inference_steps: document.querySelector('#num_inference_steps').value, num_inference_steps: numInferenceStepsField.value,
guidance_scale: document.querySelector('#guidance_scale').value / 10, guidance_scale: guidanceScaleField.value / 10,
width: document.querySelector('#width').value, width: widthField.value,
height: document.querySelector('#height').value, height: heightField.value,
seed: seed, seed: seed,
} }
if (initImagePreview.src.indexOf('data:image/png;base64') !== -1) {
reqBody['init_image'] = initImagePreview.src
}
let res = '' let res = ''
let time = new Date().getTime() let time = new Date().getTime()
@ -180,8 +241,8 @@ async function makeImage() {
setStatus('request', 'error', 'error') setStatus('request', 'error', 'error')
} }
btn.innerHTML = 'Make Image' makeImageBtn.innerHTML = 'Make Image'
btn.disabled = false; makeImageBtn.disabled = false
playSound() playSound()
@ -194,8 +255,7 @@ async function makeImage() {
outputMsg.innerHTML = 'Processed in ' + time + ' seconds. Seed: ' + seed outputMsg.innerHTML = 'Processed in ' + time + ' seconds. Seed: ' + seed
let images = document.querySelector('#images') imagesContainer.innerHTML = ''
images.innerHTML = ''
for (let idx in res.output) { for (let idx in res.output) {
let imgBody = '' let imgBody = ''
@ -208,55 +268,97 @@ async function makeImage() {
return return
} }
let imgItem = document.createElement('div')
imgItem.className = 'imgItem'
let img = document.createElement('img') let img = document.createElement('img')
img.width = parseInt(reqBody.width) img.width = parseInt(reqBody.width)
img.height = parseInt(reqBody.height) img.height = parseInt(reqBody.height)
img.src = imgBody img.src = imgBody
images.appendChild(img) let imgUseBtn = document.createElement('button')
imgUseBtn.className = 'imgUseBtn'
imgUseBtn.innerHTML = 'Use as Input'
imgItem.appendChild(img)
imgItem.appendChild(imgUseBtn)
imagesContainer.appendChild(imgItem)
imgUseBtn.addEventListener('click', function() {
initImageSelector.value = null
initImagePreview.src = imgBody
initImagePreviewContainer.style.display = 'block'
randomSeedField.checked = false
seedField.value = seed
seedField.disabled = false
})
} }
setStatus('request', 'done', 'success') setStatus('request', 'done', 'success')
if (random_seed.checked) { if (randomSeedField.checked) {
let seedEl = document.querySelector("#seed") seedField.value = seed
seedEl.value = seed
} }
} }
document.querySelector('#makeImage').addEventListener('click', makeImage) makeImageBtn.addEventListener('click', makeImage)
let config = document.querySelector('#config') configBox.style.display = 'none'
config.style.display = 'none'
document.querySelector('#configToggleBtn').addEventListener('click', function() { showConfigToggle.addEventListener('click', function() {
config.style.display = (config.style.display === 'none' ? 'block' : 'none') configBox.style.display = (configBox.style.display === 'none' ? 'block' : 'none')
document.querySelector('#configToggleBtn').innerHTML = (config.style.display === 'none' ? 'show' : 'hide') showConfigToggle.innerHTML = (configBox.style.display === 'none' ? 'show' : 'hide')
return false return false
}) })
let guidanceScale = document.querySelector('#guidance_scale')
function updateGuidanceScale() { function updateGuidanceScale() {
let label = document.querySelector('#guidance_scale_value') guidanceScaleValueLabel.innerHTML = guidanceScaleField.value / 10
label.innerHTML = guidanceScale.value / 10
} }
guidanceScale.addEventListener('input', updateGuidanceScale) guidanceScaleField.addEventListener('input', updateGuidanceScale)
updateGuidanceScale() updateGuidanceScale()
let random_seed = document.querySelector("#random_seed")
function checkRandomSeed() { function checkRandomSeed() {
let seed = document.querySelector("#seed") if (randomSeedField.checked) {
if (random_seed.checked) { seedField.disabled = true
seed.disabled = true seedField.value = "random"
seed.value = "random"
} else { } else {
seed.disabled = false seedField.disabled = false
} }
} }
random_seed.addEventListener('input', checkRandomSeed) randomSeedField.addEventListener('input', checkRandomSeed)
checkRandomSeed() checkRandomSeed()
function showInitImagePreview() {
if (initImageSelector.files.length === 0) {
initImagePreviewContainer.style.display = 'none'
return
}
let reader = new FileReader()
let file = initImageSelector.files[0]
reader.addEventListener('load', function() {
// console.log(file.name, reader.result)
initImagePreview.src = reader.result
initImagePreviewContainer.style.display = 'block'
})
if (file) {
reader.readAsDataURL(file)
}
}
initImageSelector.addEventListener('change', showInitImagePreview)
showInitImagePreview()
initImageClearBtn.addEventListener('click', function() {
initImageSelector.value = null
initImagePreview.src = ''
initImagePreviewContainer.style.display = 'none'
})
setInterval(healthCheck, HEALTH_PING_INTERVAL * 1000) setInterval(healthCheck, HEALTH_PING_INTERVAL * 1000)
</script> </script>

View File

@ -1,4 +1,4 @@
from fastapi import FastAPI from fastapi import FastAPI, HTTPException
from starlette.responses import FileResponse from starlette.responses import FileResponse
from pydantic import BaseModel from pydantic import BaseModel
@ -12,6 +12,7 @@ app = FastAPI()
# defaults from https://huggingface.co/blog/stable_diffusion # defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel): class ImageRequest(BaseModel):
prompt: str prompt: str
init_image: str = None # base64
num_outputs: str = "1" num_outputs: str = "1"
num_inference_steps: str = "50" num_inference_steps: str = "50"
guidance_scale: str = "7.5" guidance_scale: str = "7.5"
@ -45,10 +46,16 @@ async def image(req : ImageRequest):
} }
} }
if req.init_image is not None:
data['input']['init_image'] = req.init_image
if req.seed == "-1": if req.seed == "-1":
del data['input']['seed'] del data['input']['seed']
res = requests.post(PREDICT_URL, json=data) res = requests.post(PREDICT_URL, json=data)
if res.status_code != 200:
raise HTTPException(status_code=500, detail=res.text)
return res.json() return res.json()
@app.get('/media/ding.mp3') @app.get('/media/ding.mp3')