diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat
index 2bcf56d2..d8981226 100644
--- a/scripts/on_sd_start.bat
+++ b/scripts/on_sd_start.bat
@@ -199,12 +199,7 @@ call WHERE uvicorn > .tmp
 
-if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
 if not exist "..\models\vae" mkdir "..\models\vae"
-if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
-echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
-echo. > "..\models\vae\Put your VAE files here.txt"
-echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
 
 @if exist "sd-v1-4.ckpt" (
     for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (
diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh
index 8682c5cc..3e1c9c58 100755
--- a/scripts/on_sd_start.sh
+++ b/scripts/on_sd_start.sh
@@ -159,12 +159,7 @@ fi
 
-mkdir -p "../models/stable-diffusion"
 mkdir -p "../models/vae"
-mkdir -p "../models/hypernetwork"
-echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
-echo "" > "../models/vae/Put your VAE files here.txt"
-echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"
 
 if [ -f "sd-v1-4.ckpt" ]; then
     model_size=`find "sd-v1-4.ckpt" -printf "%s"`
diff --git a/ui/index.html b/ui/index.html
index d3fb6da3..0094201b 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -409,7 +409,6 @@ async function init() {
     await initSettings()
     await getModels()
-    await getDiskPath()
     await getAppConfig()
     await loadUIPlugins()
     await loadModifiers()
diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js
index 5c522d7e..3643d1ec 100644
--- a/ui/media/js/parameters.js
+++ b/ui/media/js/parameters.js
@@ -327,20 +327,10 @@ autoPickGPUsField.addEventListener('click', function() {
     gpuSettingEntry.style.display = (this.checked ? 'none' : '')
 })
 
-async function getDiskPath() {
-    try {
-        var diskPath = getSetting("diskPath")
-        if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
-            let res = await fetch('/get/output_dir')
-            if (res.status === 200) {
-                res = await res.json()
-                res = res.output_dir
-
-                setSetting("diskPath", res)
-            }
-        }
-    } catch (e) {
-        console.log('error fetching output dir path', e)
+async function setDiskPath(defaultDiskPath) {
+    var diskPath = getSetting("diskPath")
+    if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
+        setSetting("diskPath", defaultDiskPath)
     }
 }
 
@@ -415,6 +405,7 @@ async function getSystemInfo() {
 
         setDeviceInfo(devices)
         setHostInfo(res['hosts'])
+        setDiskPath(res['default_output_dir'])
     } catch (e) {
         console.log('error fetching devices', e)
     }
diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py
index bf16bb5e..b001d3f9 100644
--- a/ui/sd_internal/__init__.py
+++ b/ui/sd_internal/__init__.py
@@ -105,6 +105,10 @@ class Response:
     request: Request
     images: list
 
+    def __init__(self, request: Request, images: list):
+        self.request = request
+        self.images = images
+
     def json(self):
         res = {
             "status": 'succeeded',
@@ -116,3 +120,6 @@ class Response:
             res["output"].append(image.json())
 
         return res
+
+class UserInitiatedStop(Exception):
+    pass
diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py
index 00d2e718..d1cec46f 100644
--- a/ui/sd_internal/app.py
+++ b/ui/sd_internal/app.py
@@ -28,11 +28,6 @@ APP_CONFIG_DEFAULTS = {
         'open_browser_on_start': True,
     },
 }
-DEFAULT_MODELS = [
-    # needed to support the legacy installations
-    'custom-model', # Check if user has a custom model, use it first.
-    'sd-v1-4', # Default fallback.
-]
 
 def init():
     os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py
index 3c1d8bf1..4de6b265 100644
--- a/ui/sd_internal/device_manager.py
+++ b/ui/sd_internal/device_manager.py
@@ -101,10 +101,8 @@ def device_init(context, device):
     context.device = device
 
     # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
-    device_name = context.device_name.lower()
-    force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
-    if force_full_precision:
-        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', context.device_name)
+    if needs_to_force_full_precision(context):
+        print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
         # Apply force_full_precision now before models are loaded.
         context.precision = 'full'
 
@@ -113,6 +111,10 @@ def device_init(context, device):
 
     return
 
+def needs_to_force_full_precision(context):
+    device_name = context.device_name.lower()
+    return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
+
 def validate_device_id(device, log_prefix=''):
     def is_valid():
         if not isinstance(device, str):
diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py
index 906038e1..a8b249a2 100644
--- a/ui/sd_internal/model_manager.py
+++ b/ui/sd_internal/model_manager.py
@@ -1,32 +1,39 @@
 import os
 
-from sd_internal import app
+from sd_internal import app, device_manager
+from sd_internal import Request
 
 import picklescan.scanner
 import rich
 
-STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
-VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
-HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
-
-default_model_to_load = None
-default_vae_to_load = None
-default_hypernetwork_to_load = None
+KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
+MODEL_EXTENSIONS = {
+    'stable-diffusion': ['.ckpt', '.safetensors'],
+    'vae': ['.vae.pt', '.ckpt'],
+    'hypernetwork': ['.pt'],
+    'gfpgan': ['.pth'],
+    'realesrgan': ['.pth'],
+}
+DEFAULT_MODELS = {
+    'stable-diffusion': [ # needed to support the legacy installations
+        'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
+        'sd-v1-4', # Default fallback.
+    ],
+    'gfpgan': ['GFPGANv1.3'],
+    'realesrgan': ['RealESRGAN_x4plus'],
+}
 
 known_models = {}
 
 def init():
-    global default_model_to_load, default_vae_to_load, default_hypernetwork_to_load
-
-    default_model_to_load = resolve_ckpt_to_use()
-    default_vae_to_load = resolve_vae_to_use()
-    default_hypernetwork_to_load = resolve_hypernetwork_to_use()
-
+    make_model_folders()
     getModels() # run this once, to cache the picklescan results
 
-def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
+def resolve_model_to_use(model_name:str, model_type:str):
+    model_extensions = MODEL_EXTENSIONS.get(model_type, [])
+    default_models = DEFAULT_MODELS.get(model_type, [])
     config = app.getConfig()
 
-    model_dirs = [os.path.join(app.MODELS_DIR, model_dir), app.SD_DIR]
+    model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
     if not model_name: # When None try user configured model.
         # config = getConfig()
         if 'model' in config and model_type in config['model']:
@@ -39,7 +46,7 @@ def resolve_model_to_use(model_name:str, model_type:str, model_ex
             model_name = 'sd-v1-4'
 
     # Check models directory
-    models_dir_path = os.path.join(app.MODELS_DIR, model_dir, model_name)
+    models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name)
     for model_extension in model_extensions:
         if os.path.exists(models_dir_path + model_extension):
             return models_dir_path + model_extension
@@ -66,14 +73,32 @@ def resolve_model_to_use(model_name:str, model_type:str, model_ex
     print(f'No valid models found for model_name: {model_name}')
     return None
 
-def resolve_ckpt_to_use(model_name:str=None):
-    return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
+def resolve_sd_model_to_use(model_name:str=None):
+    return resolve_model_to_use(model_name, model_type='stable-diffusion')
 
-def resolve_vae_to_use(model_name:str=None):
-    return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
+def resolve_vae_model_to_use(model_name:str=None):
+    return resolve_model_to_use(model_name, model_type='vae')
 
-def resolve_hypernetwork_to_use(model_name:str=None):
-    return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
+def resolve_hypernetwork_model_to_use(model_name:str=None):
+    return resolve_model_to_use(model_name, model_type='hypernetwork')
+
+def resolve_gfpgan_model_to_use(model_name:str=None):
+    return resolve_model_to_use(model_name, model_type='gfpgan')
+
+def resolve_realesrgan_model_to_use(model_name:str=None):
+    return resolve_model_to_use(model_name, model_type='realesrgan')
+
+def make_model_folders():
+    for model_type in KNOWN_MODEL_TYPES:
+        model_dir_path = os.path.join(app.MODELS_DIR, model_type)
+
+        os.makedirs(model_dir_path, exist_ok=True)
+
+        help_file_name = f'Place your {model_type} model files here.txt'
+        help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'
+
+        with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f:
+            f.write(help_file_contents)
 
 def is_malicious_model(file_path):
     try:
@@ -102,8 +127,9 @@ def getModels():
         },
     }
 
-    def listModels(models_dirname, model_type, model_extensions):
-        models_dir = os.path.join(app.MODELS_DIR, models_dirname)
+    def listModels(model_type):
+        model_extensions = MODEL_EXTENSIONS.get(model_type, [])
+        models_dir = os.path.join(app.MODELS_DIR, model_type)
         if not os.path.exists(models_dir):
             os.makedirs(models_dir)
 
@@ -128,9 +154,9 @@ def getModels():
         models['options'][model_type].sort()
 
     # custom models
-    listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
-    listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
-    listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
+    listModels(model_type='stable-diffusion')
+    listModels(model_type='vae')
+    listModels(model_type='hypernetwork')
 
     # legacy
     custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
@@ -138,3 +164,19 @@ def getModels():
         models['options']['stable-diffusion'].append('custom-model')
 
     return models
+
+def is_sd_model_reload_necessary(thread_data, req: Request):
+    needs_model_reload = False
+    if 'stable-diffusion' not in thread_data.models or \
+        thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \
+        thread_data.model_paths['vae'] != req.use_vae_model:
+
+        needs_model_reload = True
+
+    if thread_data.device != 'cpu':
+        if (thread_data.precision == 'autocast' and req.use_full_precision) or \
+            (thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)):
+            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
+            needs_model_reload = True
+
+    return needs_model_reload
diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py
index fc8d944d..840cec09 100644
--- a/ui/sd_internal/runtime2.py
+++ b/ui/sd_internal/runtime2.py
@@ -1,16 +1,23 @@
 import threading
 import queue
+import time
+import json
+import os
+import base64
+import re
 
 from sd_internal import device_manager, model_manager
-from sd_internal import Request, Response, Image as ResponseImage
+from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
 
-from modules import model_loader, image_generator, image_utils
+from modules import model_loader, image_generator, image_utils, image_filters
 
 thread_data = threading.local()
 '''
 runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
 '''
 
+filename_regex = re.compile('[^a-zA-Z0-9]')
+
 def init(device):
     '''
     Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device
@@ -28,89 +35,167 @@ def init(device):
 
     device_manager.device_init(thread_data, device)
 
-    load_default_models()
+    init_and_load_default_models()
 
 def destroy():
     model_loader.unload_sd_model(thread_data)
     model_loader.unload_gfpgan_model(thread_data)
     model_loader.unload_realesrgan_model(thread_data)
 
-def load_default_models():
-    thread_data.model_paths['stable-diffusion'] = model_manager.default_model_to_load
-    thread_data.model_paths['vae'] = model_manager.default_vae_to_load
+def init_and_load_default_models():
+    # init default model paths
+    thread_data.model_paths['stable-diffusion'] = model_manager.resolve_sd_model_to_use()
+    thread_data.model_paths['vae'] = model_manager.resolve_vae_model_to_use()
+    thread_data.model_paths['hypernetwork'] = model_manager.resolve_hypernetwork_model_to_use()
+    thread_data.model_paths['gfpgan'] = model_manager.resolve_gfpgan_model_to_use()
+    thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
 
+    # load mandatory models
     model_loader.load_sd_model(thread_data)
 
-def reload_models_if_necessary(req: Request=None):
-    needs_model_reload = False
-    if 'stable-diffusion' not in thread_data.models or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
-        thread_data.ckpt_file = req.use_stable_diffusion_model
-        thread_data.vae_file = req.use_vae_model
-        needs_model_reload = True
+def reload_models_if_necessary(req: Request):
+    if model_manager.is_sd_model_reload_necessary(thread_data, req):
+        thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
+        thread_data.model_paths['vae'] = req.use_vae_model
 
-    if thread_data.device != 'cpu':
-        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
-            (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
-            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
-            needs_model_reload = True
+        model_loader.load_sd_model(thread_data)
 
-    return needs_model_reload
+    # if is_hypernetwork_reload_necessary(task.request):
+    #     current_state = ServerStates.LoadingModel
+    #     runtime.reload_hypernetwork()
 
-    if is_hypernetwork_reload_necessary(task.request):
-        current_state = ServerStates.LoadingModel
-        runtime.reload_hypernetwork()
+def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
+    images = apply_filters(req, images, user_stopped)
 
-    if is_model_reload_necessary(task.request):
-        current_state = ServerStates.LoadingModel
-        runtime.reload_model()
+    save_images(req, images)
 
-def load_models():
-    if ckpt_file_path == None:
-        ckpt_file_path = default_model_to_load
-    if vae_file_path == None:
-        vae_file_path = default_vae_to_load
-    if hypernetwork_file_path == None:
-        hypernetwork_file_path = default_hypernetwork_to_load
-    if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
-        return
-    current_state = ServerStates.LoadingModel
-    try:
-        from sd_internal import runtime2
-        runtime.thread_data.hypernetwork_file = hypernetwork_file_path
-        runtime.thread_data.ckpt_file = ckpt_file_path
-        runtime.thread_data.vae_file = vae_file_path
-        runtime.load_model_ckpt()
-        runtime.load_hypernetwork()
-        current_model_path = ckpt_file_path
-        current_vae_path = vae_file_path
-        current_hypernetwork_path = hypernetwork_file_path
-        current_state_error = None
-        current_state = ServerStates.Online
-    except Exception as e:
-        current_model_path = None
-        current_vae_path = None
-        current_state_error = e
-        current_state = ServerStates.Unavailable
-        print(traceback.format_exc())
+    return Response(req, images=construct_response(req, images))
+
+def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    thread_data.temp_images.clear()
+
+    image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback)
 
-def make_image(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
         images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req))
+        user_stopped = False
     except UserInitiatedStop:
-        pass
+        images = []
+        user_stopped = True
+        if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
+            return images, user_stopped
+        for i in range(req.num_outputs):
+            images.append(image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0)))
+
+        del thread_data.partial_x_samples
+    finally:
+        model_loader.gc(thread_data)
+
+    images = [(image, req.seed + i, False) for i, image in enumerate(images)]
+
+    return images, user_stopped
+
+def apply_filters(req: Request, images: list, user_stopped):
+    if user_stopped or (req.use_face_correction is None and req.use_upscale is None):
+        return images
+
+    filters = []
+    if req.use_face_correction and req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction)))
+    if req.use_upscale and req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale)))
+
+    filtered_images = []
+    for img, seed, _ in images:
+        for filter_fn, filter_model_path in filters:
+            img = filter_fn(thread_data, img, filter_model_path)
+
+        filtered_images.append((img, seed, True))
+
+    if not req.show_only_filtered_image:
+        filtered_images = images + filtered_images
+
+    return filtered_images
+
+def save_images(req: Request, images: list):
+    if req.save_to_disk_path is None:
+        return
+
+    def get_image_id(i):
+        img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
+        img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
+        return img_id
+
+    def get_image_basepath(i):
+        session_out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
+        os.makedirs(session_out_path, exist_ok=True)
+        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
+        return os.path.join(session_out_path, f"{prompt_flattened}_{get_image_id(i)}")
+
+    for i, img_data in enumerate(images):
+        img, seed, filtered = img_data
+        img_path = get_image_basepath(i)
+
+        if not filtered or req.show_only_filtered_image:
+            img_metadata_path = img_path + '.txt'
+            metadata = req.json()
+            metadata['seed'] = seed
+            with open(img_metadata_path, 'w', encoding='utf-8') as f:
+                f.write(json.dumps(metadata)) # req.json() returns a dict, so serialize it before writing
+
+        img_path += '_filtered' if filtered else ''
+        img_path += '.' + req.output_format
+        img.save(img_path, quality=req.output_quality)
+
+def construct_response(req: Request, images: list):
+    return [
+        ResponseImage(
+            data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality),
+            seed=seed
+        ) for img, seed, _ in images
+    ]
 
 def get_mk_img_args(req: Request):
     args = req.json()
 
-    if req.init_image is not None:
-        args['init_image'] = image_utils.base64_str_to_img(req.init_image)
-
-    if req.mask is not None:
-        args['mask'] = image_utils.base64_str_to_img(req.mask)
+    args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
+    args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None
 
     return args
 
-def on_image_step(x_samples, i):
-    pass
+def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
+    last_callback_time = -1
 
-image_generator.on_image_step = on_image_step
+    def update_temp_img(req, x_samples, task_temp_images: list):
+        partial_images = []
+        for i in range(req.num_outputs):
+            img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
+            buf = image_utils.img_to_buffer(img, output_format='JPEG')
+
+            del img
+
+            thread_data.temp_images[f'{req.request_id}/{i}'] = buf
+            task_temp_images[i] = buf
+            partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
+        return partial_images
+
+    def on_image_step(x_samples, i):
+        nonlocal last_callback_time
+
+        thread_data.partial_x_samples = x_samples
+        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
+        last_callback_time = time.time()
+
+        progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
+
+        if req.stream_image_progress and i % 5 == 0:
+            progress['output'] = update_temp_img(req, x_samples, task_temp_images)
+
+        data_queue.put(json.dumps(progress))
+
+        step_callback()
+
+        if thread_data.stop_processing:
+            raise UserInitiatedStop("User requested that we stop processing")
+
+    return on_image_step
diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py
index 04cc9a69..c6ab9737 100644
--- a/ui/sd_internal/task_manager.py
+++ b/ui/sd_internal/task_manager.py
@@ -324,7 +324,7 @@ def thread_render(device):
             runtime2.reload_models_if_necessary(task.request)
 
             current_state = ServerStates.Rendering
-            task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback)
+            task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback)
         # Before looping back to the generator, mark cache as still alive.
         task_cache.keep(id(task), TASK_TTL)
         session_cache.keep(task.request.session_id, TASK_TTL)
diff --git a/ui/server.py b/ui/server.py
index 1cb4c910..fd2557e2 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -137,9 +137,9 @@ def ping(session_id:str=None):
 def render(req : task_manager.ImageRequest):
     try:
         app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
-        req.use_stable_diffusion_model = model_manager.resolve_ckpt_to_use(req.use_stable_diffusion_model)
-        req.use_vae_model = model_manager.resolve_vae_to_use(req.use_vae_model)
-        req.use_hypernetwork_model = model_manager.resolve_hypernetwork_to_use(req.use_hypernetwork_model)
+        req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model)
+        req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model)
+        req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model)
 
         new_task = task_manager.render(req)
         response = {
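
Note (illustrative, not part of the patch): every model lookup now funnels through the single table-driven resolve_model_to_use(model_name, model_type) entry point, keyed by MODEL_EXTENSIONS and DEFAULT_MODELS. The standalone Python sketch below shows roughly the search order that entry point implements; the MODELS_DIR/SD_DIR constants and the resolve_sketch() harness are illustrative stand-ins for app.MODELS_DIR/app.SD_DIR, not code from the patch.

import os

MODELS_DIR = '../models'   # stand-in for app.MODELS_DIR
SD_DIR = '.'               # stand-in for app.SD_DIR (the legacy location)

MODEL_EXTENSIONS = {'stable-diffusion': ['.ckpt', '.safetensors']}
DEFAULT_MODELS = {'stable-diffusion': ['custom-model', 'sd-v1-4']}

def resolve_sketch(model_name, model_type):
    # Try the requested name first; with no name, fall back to the
    # per-type defaults (legacy 'custom-model', then 'sd-v1-4').
    candidates = [model_name] if model_name else DEFAULT_MODELS.get(model_type, [])
    for name in candidates:
        # Search models/<model_type>/ first, then the legacy SD_DIR,
        # trying each known extension for this model type.
        for base_dir in (os.path.join(MODELS_DIR, model_type), SD_DIR):
            for ext in MODEL_EXTENSIONS.get(model_type, []):
                path = os.path.join(base_dir, name + ext)
                if os.path.exists(path):
                    return path
    return None

print(resolve_sketch(None, 'stable-diffusion'))  # e.g. ../models/stable-diffusion/sd-v1-4.ckpt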
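
Note (illustrative, not part of the patch): make_images() now threads (image, seed, filtered) tuples through generate_images -> apply_filters -> save_images, which is how originals and filtered copies can coexist when show_only_filtered_image is off. A minimal, self-contained sketch of that tuple flow, with an identity function standing in for the GFPGAN/RealESRGAN filters:

def sketch_pipeline(raw_images, seed, filters, show_only_filtered):
    # Tag each output as (image, seed, was_filtered), like generate_images().
    images = [(img, seed + i, False) for i, img in enumerate(raw_images)]

    if not filters:
        return images

    filtered_images = []
    for img, s, _ in images:
        for filter_fn in filters:
            img = filter_fn(img)
        filtered_images.append((img, s, True))

    if not show_only_filtered:
        filtered_images = images + filtered_images  # keep both copies, like apply_filters()

    return filtered_images

# Two outputs and one identity 'filter' yield four entries: 2 originals + 2 filtered.
outputs = sketch_pipeline(['imgA', 'imgB'], 42, [lambda im: im], False)
assert len(outputs) == 4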