Merge pull request #185 from cmdr2/beta

Stop button, finer control for guidance scale and prompt weight and a few bug fixes
This commit is contained in:
cmdr2 2022-09-14 11:47:47 +05:30 committed by GitHub
commit 64cc2567bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 202 additions and 27 deletions

View File

@ -17,6 +17,8 @@
@call git pull @call git pull
@call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9 @call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
@call git apply ..\ui\sd_internal\ddim_callback.patch
@cd .. @cd ..
) else ( ) else (
@echo. & echo "Downloading Stable Diffusion.." & echo. @echo. & echo "Downloading Stable Diffusion.." & echo.
@ -31,6 +33,9 @@
@cd stable-diffusion @cd stable-diffusion
@call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9 @call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
@call git apply ..\ui\sd_internal\ddim_callback.patch
@cd .. @cd ..
) )
@ -295,6 +300,8 @@ call WHERE uvicorn > .tmp
@set SD_UI_PATH=%cd%\ui @set SD_UI_PATH=%cd%\ui
@cd stable-diffusion @cd stable-diffusion
@call python --version
@uvicorn server:app --app-dir "%SD_UI_PATH%" --port 9000 --host 0.0.0.0 @uvicorn server:app --app-dir "%SD_UI_PATH%" --port 9000 --host 0.0.0.0
@pause @pause

View File

@ -18,6 +18,8 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
git pull git pull
git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9 git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
git apply ../ui/sd_internal/ddim_callback.patch
cd .. cd ..
else else
printf "\n\nDownloading Stable Diffusion..\n\n" printf "\n\nDownloading Stable Diffusion..\n\n"
@ -32,6 +34,9 @@ else
cd stable-diffusion cd stable-diffusion
git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9 git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
git apply ../ui/sd_internal/ddim_callback.patch
cd .. cd ..
fi fi
@ -287,6 +292,8 @@ cd ..
export SD_UI_PATH=`pwd`/ui export SD_UI_PATH=`pwd`/ui
cd stable-diffusion cd stable-diffusion
python --version
uvicorn server:app --app-dir "$SD_UI_PATH" --port 9000 --host 0.0.0.0 uvicorn server:app --app-dir "$SD_UI_PATH" --port 9000 --host 0.0.0.0
read -p "Press any key to continue" read -p "Press any key to continue"

View File

@ -54,7 +54,7 @@
#editor-settings-entries li { #editor-settings-entries li {
padding-bottom: 3pt; padding-bottom: 3pt;
} }
#guidance_scale { .editor-slider {
transform: translateY(30%); transform: translateY(30%);
} }
#outputMsg { #outputMsg {
@ -150,6 +150,19 @@
#makeImage:hover { #makeImage:hover {
background: rgb(93, 0, 214); background: rgb(93, 0, 214);
} }
#stopImage {
flex: 0 0 70px;
background: rgb(132, 8, 0);
border: 2px solid rgb(122, 29, 0);
color: rgb(255, 221, 255);
width: 100%;
height: 30pt;
border-radius: 6px;
display: none;
}
#stopImage:hover {
background: rgb(214, 32, 0);
}
.flex-container { .flex-container {
display: flex; display: flex;
} }
@ -292,6 +305,7 @@
</div> </div>
<button id="makeImage">Make Image</button> <button id="makeImage">Make Image</button>
<button id="stopImage">Stop</button>
</div> </div>
<div class="line-separator">&nbsp;</div> <div class="line-separator">&nbsp;</div>
@ -358,8 +372,8 @@
</select> </select>
</li> </li>
<li><label for="num_inference_steps">Number of inference steps:</label> <input id="num_inference_steps" name="num_inference_steps" size="4" value="50"></li> <li><label for="num_inference_steps">Number of inference steps:</label> <input id="num_inference_steps" name="num_inference_steps" size="4" value="50"></li>
<li><label for="guidance_scale">Guidance Scale:</label> <input id="guidance_scale" name="guidance_scale" value="75" type="range" min="10" max="200"> <span id="guidance_scale_value"></span></li> <li><label for="guidance_scale_slider">Guidance Scale:</label> <input id="guidance_scale_slider" name="guidance_scale_slider" class="editor-slider" value="75" type="range" min="10" max="200"> <input id="guidance_scale" name="guidance_scale" size="4"></li>
<li><span id="prompt_strength_container"><label for="prompt_strength">Prompt Strength:</label> <input id="prompt_strength" name="prompt_strength" value="8" type="range" min="0" max="10"> <span id="prompt_strength_value"></span><br/></span></li> <li><span id="prompt_strength_container"><label for="prompt_strength_slider">Prompt Strength:</label> <input id="prompt_strength_slider" name="prompt_strength_slider" class="editor-slider" value="80" type="range" min="0" max="99"> <input id="prompt_strength" name="prompt_strength" size="4"><br/></span></li>
<li>&nbsp;</li> <li>&nbsp;</li>
<li><input id="save_to_disk" name="save_to_disk" type="checkbox"> <label for="save_to_disk">Automatically save to <input id="diskPath" name="diskPath" size="40" disabled></label></li> <li><input id="save_to_disk" name="save_to_disk" type="checkbox"> <label for="save_to_disk">Automatically save to <input id="diskPath" name="diskPath" size="40" disabled></label></li>
<li><input id="sound_toggle" name="sound_toggle" type="checkbox" checked> <label for="sound_toggle">Play sound on task completion</label></li> <li><input id="sound_toggle" name="sound_toggle" type="checkbox" checked> <label for="sound_toggle">Play sound on task completion</label></li>
@ -423,8 +437,8 @@ let promptField = document.querySelector('#prompt')
let numOutputsTotalField = document.querySelector('#num_outputs_total') let numOutputsTotalField = document.querySelector('#num_outputs_total')
let numOutputsParallelField = document.querySelector('#num_outputs_parallel') let numOutputsParallelField = document.querySelector('#num_outputs_parallel')
let numInferenceStepsField = document.querySelector('#num_inference_steps') let numInferenceStepsField = document.querySelector('#num_inference_steps')
let guidanceScaleSlider = document.querySelector('#guidance_scale_slider')
let guidanceScaleField = document.querySelector('#guidance_scale') let guidanceScaleField = document.querySelector('#guidance_scale')
let guidanceScaleValueLabel = document.querySelector('#guidance_scale_value')
let randomSeedField = document.querySelector("#random_seed") let randomSeedField = document.querySelector("#random_seed")
let seedField = document.querySelector('#seed') let seedField = document.querySelector('#seed')
let widthField = document.querySelector('#width') let widthField = document.querySelector('#width')
@ -440,8 +454,8 @@ let saveToDiskField = document.querySelector('#save_to_disk')
let diskPathField = document.querySelector('#diskPath') let diskPathField = document.querySelector('#diskPath')
// let allowNSFWField = document.querySelector("#allow_nsfw") // let allowNSFWField = document.querySelector("#allow_nsfw")
let useBetaChannelField = document.querySelector("#use_beta_channel") let useBetaChannelField = document.querySelector("#use_beta_channel")
let promptStrengthSlider = document.querySelector('#prompt_strength_slider')
let promptStrengthField = document.querySelector('#prompt_strength') let promptStrengthField = document.querySelector('#prompt_strength')
let promptStrengthValueLabel = document.querySelector('#prompt_strength_value')
let useFaceCorrectionField = document.querySelector("#use_face_correction") let useFaceCorrectionField = document.querySelector("#use_face_correction")
let useUpscalingField = document.querySelector("#use_upscale") let useUpscalingField = document.querySelector("#use_upscale")
let upscaleModelField = document.querySelector("#upscale_model") let upscaleModelField = document.querySelector("#upscale_model")
@ -449,6 +463,7 @@ let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_ima
let updateBranchLabel = document.querySelector("#updateBranchLabel") let updateBranchLabel = document.querySelector("#updateBranchLabel")
let makeImageBtn = document.querySelector('#makeImage') let makeImageBtn = document.querySelector('#makeImage')
let stopImageBtn = document.querySelector('#stopImage')
let imagesContainer = document.querySelector('#current-images') let imagesContainer = document.querySelector('#current-images')
let initImagePreviewContainer = document.querySelector('#init_image_preview_container') let initImagePreviewContainer = document.querySelector('#init_image_preview_container')
@ -480,6 +495,7 @@ let modifiersPanelHandle = document.querySelector("#editor-modifiers .collapsibl
let serverStatus = 'offline' let serverStatus = 'offline'
let activeTags = [] let activeTags = []
let lastPromptUsed = '' let lastPromptUsed = ''
let taskStopped = true
function getLocalStorageItem(key, fallback) { function getLocalStorageItem(key, fallback) {
let item = localStorage.getItem(key) let item = localStorage.getItem(key)
@ -616,6 +632,10 @@ async function healthCheck() {
// makes a single image. don't call this directly, use makeImage() instead // makes a single image. don't call this directly, use makeImage() instead
async function doMakeImage(reqBody) { async function doMakeImage(reqBody) {
if (taskStopped) {
return
}
let res = '' let res = ''
let seed = reqBody['seed'] let seed = reqBody['seed']
@ -782,7 +802,10 @@ async function makeImage() {
setStatus('request', 'fetching..') setStatus('request', 'fetching..')
makeImageBtn.innerHTML = 'Processing..' makeImageBtn.innerHTML = 'Processing..'
makeImageBtn.disabled = true makeImageBtn.style.display = 'none'
stopImageBtn.style.display = 'block'
taskStopped = false
let seed = (randomSeedField.checked ? Math.floor(Math.random() * 10000000) : parseInt(seedField.value)) let seed = (randomSeedField.checked ? Math.floor(Math.random() * 10000000) : parseInt(seedField.value))
let numOutputsTotal = parseInt(numOutputsTotalField.value) let numOutputsTotal = parseInt(numOutputsTotalField.value)
@ -802,7 +825,7 @@ async function makeImage() {
prompt: prompt, prompt: prompt,
num_outputs: batchSize, num_outputs: batchSize,
num_inference_steps: numInferenceStepsField.value, num_inference_steps: numInferenceStepsField.value,
guidance_scale: parseInt(guidanceScaleField.value) / 10, guidance_scale: guidanceScaleField.value,
width: widthField.value, width: widthField.value,
height: heightField.value, height: heightField.value,
// allow_nsfw: allowNSFWField.checked, // allow_nsfw: allowNSFWField.checked,
@ -813,7 +836,7 @@ async function makeImage() {
if (IMAGE_REGEX.test(initImagePreview.src)) { if (IMAGE_REGEX.test(initImagePreview.src)) {
reqBody['init_image'] = initImagePreview.src reqBody['init_image'] = initImagePreview.src
reqBody['prompt_strength'] = parseInt(promptStrengthField.value) / 10 reqBody['prompt_strength'] = promptStrengthField.value
// if (IMAGE_REGEX.test(maskImagePreview.src)) { // if (IMAGE_REGEX.test(maskImagePreview.src)) {
// reqBody['mask'] = maskImagePreview.src // reqBody['mask'] = maskImagePreview.src
@ -854,6 +877,8 @@ async function makeImage() {
makeImageBtn.innerHTML = 'Make Image' makeImageBtn.innerHTML = 'Make Image'
makeImageBtn.disabled = false makeImageBtn.disabled = false
makeImageBtn.style.display = 'block'
stopImageBtn.style.display = 'none'
if (isSoundEnabled()) { if (isSoundEnabled()) {
playSound() playSound()
@ -880,9 +905,9 @@ function createFileName() {
// Most important information is the prompt // Most important information is the prompt
let underscoreName = lastPromptUsed.replace(/[^a-zA-Z0-9]/g, '_') let underscoreName = lastPromptUsed.replace(/[^a-zA-Z0-9]/g, '_')
underscoreName = underscoreName.substring(0, 100) underscoreName = underscoreName.substring(0, 100)
const seed = seedField.value; const seed = seedField.value
const steps = numInferenceStepsField.value; const steps = numInferenceStepsField.value
const guidance = guidanceScaleField.value; const guidance = guidanceScaleField.value
// name and the top level metadata // name and the top level metadata
let fileName = `${underscoreName}_Seed-${seed}_Steps-${steps}_Guidance-${guidance}` let fileName = `${underscoreName}_Seed-${seed}_Steps-${steps}_Guidance-${guidance}`
@ -909,6 +934,18 @@ function createFileName() {
return fileName return fileName
} }
stopImageBtn.addEventListener('click', async function() {
try {
let res = await fetch('/image/stop')
} catch (e) {
console.log(e)
}
stopImageBtn.style.display = 'none'
makeImageBtn.style.display = 'block'
taskStopped = true
})
soundToggle.addEventListener('click', handleBoolSettingChange(SOUND_ENABLED_KEY)) soundToggle.addEventListener('click', handleBoolSettingChange(SOUND_ENABLED_KEY))
soundToggle.checked = isSoundEnabled() soundToggle.checked = isSoundEnabled()
@ -964,17 +1001,39 @@ makeImageBtn.addEventListener('click', makeImage)
function updateGuidanceScale() { function updateGuidanceScale() {
guidanceScaleValueLabel.innerHTML = guidanceScaleField.value / 10 guidanceScaleField.value = guidanceScaleSlider.value / 10
} }
guidanceScaleField.addEventListener('input', updateGuidanceScale) function updateGuidanceScaleSlider() {
if (guidanceScaleField.value < 0) {
guidanceScaleField.value = 0
} else if (guidanceScaleField.value > 20) {
guidanceScaleField.value = 20
}
guidanceScaleSlider.value = guidanceScaleField.value * 10
}
guidanceScaleSlider.addEventListener('input', updateGuidanceScale)
guidanceScaleField.addEventListener('input', updateGuidanceScaleSlider)
updateGuidanceScale() updateGuidanceScale()
function updatePromptStrength() { function updatePromptStrength() {
promptStrengthValueLabel.innerHTML = promptStrengthField.value / 10 promptStrengthField.value = promptStrengthSlider.value / 100
} }
promptStrengthField.addEventListener('input', updatePromptStrength) function updatePromptStrengthSlider() {
if (promptStrengthField.value < 0) {
promptStrengthField.value = 0
} else if (promptStrengthField.value > 0.99) {
promptStrengthField.value = 0.99
}
promptStrengthSlider.value = promptStrengthField.value * 100
}
promptStrengthSlider.addEventListener('input', updatePromptStrength)
promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
updatePromptStrength() updatePromptStrength()
useBetaChannelField.addEventListener('click', async function(e) { useBetaChannelField.addEventListener('click', async function(e) {

View File

@ -21,6 +21,20 @@ class Request:
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
show_only_filtered_image: bool = False show_only_filtered_image: bool = False
def json(self):
return {
"prompt": self.prompt,
"num_outputs": self.num_outputs,
"num_inference_steps": self.num_inference_steps,
"guidance_scale": self.guidance_scale,
"width": self.width,
"height": self.height,
"seed": self.seed,
"prompt_strength": self.prompt_strength,
"use_face_correction": self.use_face_correction,
"use_upscale": self.use_upscale,
}
def to_string(self): def to_string(self):
return f''' return f'''
prompt: {self.prompt} prompt: {self.prompt}
@ -42,6 +56,7 @@ class Image:
data: str # base64 data: str # base64
seed: int seed: int
is_nsfw: bool is_nsfw: bool
path_abs: str = None
def __init__(self, data, seed): def __init__(self, data, seed):
self.data = data self.data = data
@ -51,14 +66,19 @@ class Image:
return { return {
"data": self.data, "data": self.data,
"seed": self.seed, "seed": self.seed,
"path_abs": self.path_abs,
} }
class Response: class Response:
request: Request
session_id: str
images: list images: list
def json(self): def json(self):
res = { res = {
"status": 'succeeded', "status": 'succeeded',
"session_id": self.session_id,
"request": self.request.json(),
"output": [], "output": [],
} }

View File

@ -0,0 +1,34 @@
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index dcf7901..1f99adc 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -528,7 +528,8 @@ class UNet(DDPM):
elif sampler == "ddim":
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
- mask = mask,init_latent=x_T,use_original_steps=False)
+ mask = mask,init_latent=x_T,use_original_steps=False,
+ callback=callback, img_callback=img_callback)
# elif sampler == "euler":
# cvd = CompVisDenoiser(self.alphas_cumprod)
@@ -687,7 +688,8 @@ class UNet(DDPM):
@torch.no_grad()
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
- mask = None,init_latent=None,use_original_steps=False):
+ mask = None,init_latent=None,use_original_steps=False,
+ callback=None, img_callback=None):
timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start]
@@ -710,6 +712,9 @@ class UNet(DDPM):
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
+
+ if callback: callback(i)
+ if img_callback: img_callback(x_dec, i)
if mask is not None:
return x0 * mask + (1. - mask) * x_dec

View File

@ -35,6 +35,7 @@ from io import BytesIO
# local # local
session_id = str(uuid.uuid4())[-8:] session_id = str(uuid.uuid4())[-8:]
stop_processing = False
ckpt_file = None ckpt_file = None
gfpgan_file = None gfpgan_file = None
@ -185,8 +186,13 @@ def load_model_real_esrgan(real_esrgan_to_use):
def mk_img(req: Request): def mk_img(req: Request):
global modelFS, device global modelFS, device
global model_gfpgan, model_real_esrgan global model_gfpgan, model_real_esrgan
global stop_processing
stop_processing = False
res = Response() res = Response()
res.session_id = session_id
res.request = req
res.images = [] res.images = []
model.turbo = req.turbo model.turbo = req.turbo
@ -267,7 +273,7 @@ def mk_img(req: Request):
else: else:
handler = _img2img handler = _img2img
init_image = load_img(req.init_image) init_image = load_img(req.init_image, opt_W, opt_H)
init_image = init_image.to(device) init_image = init_image.to(device)
if device != "cpu" and precision == "autocast": if device != "cpu" and precision == "autocast":
@ -320,11 +326,26 @@ def mk_img(req: Request):
else: else:
c = modelCS.get_learned_conditioning(prompts) c = modelCS.get_learned_conditioning(prompts)
partial_x_samples = None
def img_callback(x_samples, i):
nonlocal partial_x_samples
partial_x_samples = x_samples
if stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
# run the handler # run the handler
try:
if handler == _txt2img: if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed) x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback)
else: else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed) x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback)
except UserInitiatedStop:
if partial_x_samples is None:
continue
x_samples = partial_x_samples
modelFS.to(device) modelFS.to(device)
@ -354,7 +375,11 @@ def mk_img(req: Request):
if not opt_show_only_filtered: if not opt_show_only_filtered:
img_data = img_to_base64_str(img) img_data = img_to_base64_str(img)
res.images.append(ResponseImage(data=img_data, seed=opt_seed)) res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
res.images.append(res_image_orig)
if opt_save_to_disk_path is not None:
res_image_orig.path_abs = img_out_path
if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \ if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')): (opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')):
@ -375,13 +400,15 @@ def mk_img(req: Request):
filtered_image = Image.fromarray(x_sample) filtered_image = Image.fromarray(x_sample)
filtered_img_data = img_to_base64_str(filtered_image) filtered_img_data = img_to_base64_str(filtered_image)
res.images.append(ResponseImage(data=filtered_img_data, seed=opt_seed)) res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed)
res.images.append(res_image_filtered)
filters_applied = "_".join(filters_applied) filters_applied = "_".join(filters_applied)
if opt_save_to_disk_path is not None: if opt_save_to_disk_path is not None:
filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}") filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}")
save_image(filtered_image, filtered_img_out_path) save_image(filtered_image, filtered_img_out_path)
res_image_filtered.path_abs = filtered_img_out_path
seeds += str(opt_seed) + "," seeds += str(opt_seed) + ","
opt_seed += 1 opt_seed += 1
@ -411,7 +438,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except: except:
print('could not save the file', traceback.format_exc()) print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed): def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f] shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu": if device != "cpu":
@ -430,12 +457,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
unconditional_conditioning=uc, unconditional_conditioning=uc,
eta=opt_ddim_eta, eta=opt_ddim_eta,
x_T=start_code, x_T=start_code,
img_callback=img_callback,
sampler = 'plms', sampler = 'plms',
) )
return samples_ddim return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed): def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback):
# encode (scaled latent) # encode (scaled latent)
z_enc = model.stochastic_encode( z_enc = model.stochastic_encode(
init_latent, init_latent,
@ -451,6 +479,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
z_enc, z_enc,
unconditional_guidance_scale=opt_scale, unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
img_callback=img_callback,
sampler = 'ddim' sampler = 'ddim'
) )
@ -479,13 +508,18 @@ def load_model_from_config(ckpt, verbose=False):
return sd return sd
# utils # utils
class UserInitiatedStop(Exception):
pass
def load_img(img_str): def load_img(img_str, w0, h0):
image = base64_str_to_img(img_str).convert("RGB") image = base64_str_to_img(img_str).convert("RGB")
w, h = image.size w, h = image.size
print(f"loaded input image of size ({w}, {h}) from base64") print(f"loaded input image of size ({w}, {h}) from base64")
if h0 is not None and w0 is not None:
h, w = h0, w0
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=Image.LANCZOS) image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image) image = torch.from_numpy(image)

View File

@ -83,7 +83,7 @@ async def ping():
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
@app.post('/image') @app.post('/image')
async def image(req : ImageRequest): def image(req : ImageRequest):
from sd_internal import runtime from sd_internal import runtime
r = Request() r = Request()
@ -114,6 +114,20 @@ async def image(req : ImageRequest):
print(traceback.format_exc()) print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
@app.get('/image/stop')
def stop():
try:
if model_is_loading:
return {'ERROR'}
from sd_internal import runtime
runtime.stop_processing = True
return {'OK'}
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
@app.post('/app_config') @app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest): async def setAppConfig(req : SetAppConfigRequest):
try: try: