diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js
index 9025f988..df044e2b 100644
--- a/ui/media/js/auto-save.js
+++ b/ui/media/js/auto-save.js
@@ -37,7 +37,6 @@ const SETTINGS_IDS_LIST = [
     "diskPath",
     "sound_toggle",
     "turbo",
-    "use_full_precision",
     "confirm_dangerous_actions",
     "auto_save_settings",
     "apply_color_correction"
@@ -278,7 +277,6 @@ function tryLoadOldSettings() {
         "soundEnabled": "sound_toggle",
         "saveToDisk": "save_to_disk",
         "useCPU": "use_cpu",
-        "useFullPrecision": "use_full_precision",
         "useTurboMode": "turbo",
         "diskPath": "diskPath",
         "useFaceCorrection": "use_face_correction",
diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js
index 53b55ac7..8b9819c2 100644
--- a/ui/media/js/dnd.js
+++ b/ui/media/js/dnd.js
@@ -217,13 +217,6 @@ const TASK_MAPPING = {
         readUI: () => turboField.checked,
         parse: (val) => Boolean(val)
     },
-    use_full_precision: { name: 'Use Full Precision',
-        setUI: (use_full_precision) => {
-            useFullPrecisionField.checked = use_full_precision
-        },
-        readUI: () => useFullPrecisionField.checked,
-        parse: (val) => Boolean(val)
-    },
     stream_image_progress: { name: 'Stream Image Progress',
         setUI: (stream_image_progress) => {
@@ -453,7 +446,6 @@ document.addEventListener("dragover", dragOverHandler)
 const TASK_REQ_NO_EXPORT = [
     "use_cpu",
     "turbo",
-    "use_full_precision",
     "save_to_disk_path"
 ]
 const resetSettings = document.getElementById('reset-image-settings')
diff --git a/ui/media/js/engine.js b/ui/media/js/engine.js
index dd34ddb1..61822c0f 100644
--- a/ui/media/js/engine.js
+++ b/ui/media/js/engine.js
@@ -728,7 +728,6 @@
             "stream_image_progress": 'boolean',
             "show_only_filtered_image": 'boolean',
             "turbo": 'boolean',
-            "use_full_precision": 'boolean',
             "output_format": 'string',
             "output_quality": 'number',
         }
@@ -744,7 +743,6 @@
             "stream_image_progress": true,
             "show_only_filtered_image": true,
             "turbo": false,
-            "use_full_precision": false,
             "output_format": "png",
             "output_quality": 75,
         }
diff --git a/ui/media/js/main.js b/ui/media/js/main.js
index 6c08aef7..e058f9e9 100644
--- a/ui/media/js/main.js
+++ b/ui/media/js/main.js
@@ -850,7 +850,6 @@ function getCurrentUserRequest() {
             // allow_nsfw: allowNSFWField.checked,
             turbo: turboField.checked,
             //render_device: undefined, // Set device affinity. Prefer this device, but wont activate.
-            use_full_precision: useFullPrecisionField.checked,
             use_stable_diffusion_model: stableDiffusionModelField.value,
             use_vae_model: vaeModelField.value,
             stream_progress_updates: true,
diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js
index 3643d1ec..f326953b 100644
--- a/ui/media/js/parameters.js
+++ b/ui/media/js/parameters.js
@@ -105,14 +105,6 @@ var PARAMETERS = [
         note: "to process in parallel",
         default: false,
     },
-    {
-        id: "use_full_precision",
-        type: ParameterType.checkbox,
-        label: "Use Full Precision",
-        note: "for GPU-only. warning: this will consume more VRAM",
-        icon: "fa-crosshairs",
-        default: false,
-    },
     {
         id: "auto_save_settings",
         type: ParameterType.checkbox,
@@ -214,7 +206,6 @@ let turboField = document.querySelector('#turbo')
 let useCPUField = document.querySelector('#use_cpu')
 let autoPickGPUsField = document.querySelector('#auto_pick_gpus')
 let useGPUsField = document.querySelector('#use_gpus')
-let useFullPrecisionField = document.querySelector('#use_full_precision')
 let saveToDiskField = document.querySelector('#save_to_disk')
 let diskPathField = document.querySelector('#diskPath')
 let listenToNetworkField = document.querySelector("#listen_to_network")
diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py
index 9dd4f066..a06220ca 100644
--- a/ui/sd_internal/__init__.py
+++ b/ui/sd_internal/__init__.py
@@ -1,93 +1,22 @@
-import json
+from pydantic import BaseModel
 
-class Request:
+from modules.types import GenerateImageRequest
+
+class TaskData(BaseModel):
     request_id: str = None
     session_id: str = "session"
-    prompt: str = ""
-    negative_prompt: str = ""
-    init_image: str = None # base64
-    mask: str = None # base64
-    apply_color_correction = False
-    num_outputs: int = 1
-    num_inference_steps: int = 50
-    guidance_scale: float = 7.5
-    width: int = 512
-    height: int = 512
-    seed: int = 42
-    prompt_strength: float = 0.8
-    sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
-    # allow_nsfw: bool = False
-    precision: str = "autocast" # or "full"
     save_to_disk_path: str = None
     turbo: bool = True
-    use_full_precision: bool = False
     use_face_correction: str = None # or "GFPGANv1.3"
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
     use_stable_diffusion_model: str = "sd-v1-4"
     use_vae_model: str = None
     use_hypernetwork_model: str = None
-    hypernetwork_strength: float = 1
     show_only_filtered_image: bool = False
     output_format: str = "jpeg" # or "png"
     output_quality: int = 75
-
-    stream_progress_updates: bool = False
     stream_image_progress: bool = False
 
-    def json(self):
-        return {
-            "request_id": self.request_id,
-            "session_id": self.session_id,
-            "prompt": self.prompt,
-            "negative_prompt": self.negative_prompt,
-            "num_outputs": self.num_outputs,
-            "num_inference_steps": self.num_inference_steps,
-            "guidance_scale": self.guidance_scale,
-            "width": self.width,
-            "height": self.height,
-            "seed": self.seed,
-            "prompt_strength": self.prompt_strength,
-            "sampler": self.sampler,
-            "apply_color_correction": self.apply_color_correction,
-            "use_face_correction": self.use_face_correction,
-            "use_upscale": self.use_upscale,
-            "use_stable_diffusion_model": self.use_stable_diffusion_model,
-            "use_vae_model": self.use_vae_model,
-            "use_hypernetwork_model": self.use_hypernetwork_model,
-            "hypernetwork_strength": self.hypernetwork_strength,
-            "output_format": self.output_format,
-            "output_quality": self.output_quality,
-        }
-
-    def __str__(self):
-        return f'''
-    session_id: {self.session_id}
-    prompt: {self.prompt}
-    negative_prompt: {self.negative_prompt}
-    seed: {self.seed}
-    num_inference_steps: {self.num_inference_steps}
-    sampler: {self.sampler}
-    guidance_scale: {self.guidance_scale}
-    w: {self.width}
-    h: {self.height}
-    precision: {self.precision}
-    save_to_disk_path: {self.save_to_disk_path}
-    turbo: {self.turbo}
-    use_full_precision: {self.use_full_precision}
-    apply_color_correction: {self.apply_color_correction}
-    use_face_correction: {self.use_face_correction}
-    use_upscale: {self.use_upscale}
-    use_stable_diffusion_model: {self.use_stable_diffusion_model}
-    use_vae_model: {self.use_vae_model}
-    use_hypernetwork_model: {self.use_hypernetwork_model}
-    hypernetwork_strength: {self.hypernetwork_strength}
-    show_only_filtered_image: {self.show_only_filtered_image}
-    output_format: {self.output_format}
-    output_quality: {self.output_quality}
-
-    stream_progress_updates: {self.stream_progress_updates}
-    stream_image_progress: {self.stream_image_progress}'''
-
 class Image:
     data: str # base64
     seed: int
@@ -106,17 +35,20 @@ class Image:
         }
 
 class Response:
-    request: Request
+    render_request: GenerateImageRequest
+    task_data: TaskData
     images: list
 
-    def __init__(self, request: Request, images: list):
-        self.request = request
+    def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
+        self.render_request = render_request
+        self.task_data = task_data
         self.images = images
 
     def json(self):
         res = {
             "status": 'succeeded',
-            "request": self.request.json(),
+            "render_request": self.render_request.dict(),
+            "task_data": self.task_data.dict(),
             "output": [],
         }
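Since `TaskData` is now a pydantic `BaseModel`, the hand-written `json()` and `__str__` serializers removed above are no longer needed: pydantic's `.dict()` covers both, and `Response.json()` nests the render and task halves instead of returning one flat blob. A minimal sketch of the new behaviour (assuming `sd_internal` is on the import path; field values are illustrative):

```python
from sd_internal import TaskData

# parse_obj() hydrates the model from a plain dict, applying defaults and
# type coercion; .dict() replaces the old hand-rolled Request.json().
task = TaskData.parse_obj({'session_id': 's1', 'use_face_correction': 'GFPGANv1.3'})
print(task.dict()['use_face_correction'])   # GFPGANv1.3
```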
diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py
index a3f91cfb..733bab50 100644
--- a/ui/sd_internal/device_manager.py
+++ b/ui/sd_internal/device_manager.py
@@ -6,6 +6,13 @@ import logging
 
 log = logging.getLogger()
 
+'''
+Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
+Otherwise the models will load at half-precision (i.e. float16).
+
+Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
+'''
+
 COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
 
 mem_free_threshold = 0
@@ -96,7 +103,7 @@ def device_init(context, device):
     if device == 'cpu':
         context.device = 'cpu'
         context.device_name = get_processor_name()
-        context.precision = 'full'
+        context.half_precision = False
         log.debug(f'Render device CPU available as {context.device_name}')
         return
 
@@ -107,7 +114,7 @@ def device_init(context, device):
     if needs_to_force_full_precision(context):
         log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
         # Apply force_full_precision now before models are loaded.
-        context.precision = 'full'
+        context.half_precision = False
 
     log.info(f'Setting {device} as active')
     torch.cuda.device(device)
@@ -115,6 +122,9 @@ def device_init(context, device):
     return
 
 def needs_to_force_full_precision(context):
+    if 'FORCE_FULL_PRECISION' in os.environ:
+        return True
+
     device_name = context.device_name.lower()
     return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
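Note that the new check is presence-only: `needs_to_force_full_precision()` never reads the variable's value, so any assignment to `FORCE_FULL_PRECISION` forces float32. (Separately, the `'Quadro T2000'` comparison on the last line can never match, since `device_name` has already been lowercased.) A quick sketch of the environment route; the `config.bat`/`config.sh` route from the docstring amounts to exporting the same variable:

```python
import os

from sd_internal import device_manager

class FakeContext:
    # hypothetical stand-in for the render context; only device_name is read
    device_name = 'NVIDIA GeForce RTX 3060'

os.environ['FORCE_FULL_PRECISION'] = '1'   # presence alone matters, the value is ignored
assert device_manager.needs_to_force_full_precision(FakeContext())
```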
diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py
index ebeb439c..2129af3d 100644
--- a/ui/sd_internal/model_manager.py
+++ b/ui/sd_internal/model_manager.py
@@ -3,8 +3,7 @@ import logging
 import picklescan.scanner
 import rich
 
-from sd_internal import app, device_manager
-from sd_internal import Request
+from sd_internal import app
 
 log = logging.getLogger()
 
@@ -157,19 +156,3 @@ def getModels():
         models['options']['stable-diffusion'].append('custom-model')
 
     return models
-
-def is_sd_model_reload_necessary(thread_data, req: Request):
-    needs_model_reload = False
-    if 'stable-diffusion' not in thread_data.models or \
-        thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \
-        thread_data.model_paths['vae'] != req.use_vae_model:
-
-        needs_model_reload = True
-
-    if thread_data.device != 'cpu':
-        if (thread_data.precision == 'autocast' and req.use_full_precision) or \
-            (thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)):
-            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
-            needs_model_reload = True
-
-    return needs_model_reload
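The retired `is_sd_model_reload_necessary()` coupled model reloads to runtime precision switching. With precision now fixed per device at startup, the check collapses to a plain path comparison, inlined into `runtime2.reload_models_if_necessary()` below; roughly:

```python
def sd_model_changed(thread_data, task_data) -> bool:
    # Equivalent of the retired check, minus the precision juggling.
    return (thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model
            or thread_data.model_paths.get('vae') != task_data.use_vae_model)
```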
diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py
index b7ce6fe5..434f489b 100644
--- a/ui/sd_internal/runtime2.py
+++ b/ui/sd_internal/runtime2.py
@@ -9,13 +9,14 @@ import traceback
 import logging
 
 from sd_internal import device_manager, model_manager
-from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
+from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
 
 from modules import model_loader, image_generator, image_utils, filters as image_filters
+from modules.types import Context, GenerateImageRequest
 
 log = logging.getLogger()
 
-thread_data = threading.local()
+thread_data = Context()
 '''
 runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
 '''
@@ -28,14 +29,7 @@ def init(device):
     '''
     thread_data.stop_processing = False
     thread_data.temp_images = {}
-
-    thread_data.models = {}
-    thread_data.model_paths = {}
-
-    thread_data.device = None
-    thread_data.device_name = None
-    thread_data.precision = 'autocast'
-    thread_data.vram_optimizations = ('TURBO', 'MOVE_MODELS')
+    thread_data.partial_x_samples = None
 
     device_manager.device_init(thread_data, device)
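`thread_data` changes from a bare `threading.local()` to a `Context` from `modules.types`, which is not part of this diff. Judging by the fields the old `init()` populated and by what `device_manager.device_init()` assigns, a plausible shape is (a sketch, not the real class):

```python
class Context:
    def __init__(self):
        self.models = {}            # loaded model objects, keyed by model type
        self.model_paths = {}       # file path backing each loaded model
        self.device = None          # e.g. 'cuda:0' or 'cpu'
        self.device_name = None
        self.half_precision = True  # replaces the old 'autocast'/'full' string
        self.vram_optimizations = ('TURBO', 'MOVE_MODELS')
```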
@@ -51,16 +45,16 @@ def load_default_models():
     # load mandatory models
     model_loader.load_model(thread_data, 'stable-diffusion')
 
-def reload_models_if_necessary(req: Request):
+def reload_models_if_necessary(task_data: TaskData):
     model_paths_in_req = (
-        ('hypernetwork', req.use_hypernetwork_model),
-        ('gfpgan', req.use_face_correction),
-        ('realesrgan', req.use_upscale),
+        ('hypernetwork', task_data.use_hypernetwork_model),
+        ('gfpgan', task_data.use_face_correction),
+        ('realesrgan', task_data.use_upscale),
     )
 
-    if model_manager.is_sd_model_reload_necessary(thread_data, req):
-        thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
-        thread_data.model_paths['vae'] = req.use_vae_model
+    if thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model:
+        thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
+        thread_data.model_paths['vae'] = task_data.use_vae_model
 
         model_loader.load_model(thread_data, 'stable-diffusion')
@@ -73,10 +67,17 @@ def reload_models_if_necessary(req: Request):
         else:
             model_loader.unload_model(thread_data, model_type)
 
-def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
-        log.info(req)
-        return _make_images_internal(req, data_queue, task_temp_images, step_callback)
+        # resolve the model paths to use
+        resolve_model_paths(task_data)
+
+        # convert init image to PIL.Image
+        req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
+        req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None
+
+        # generate
+        return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
     except Exception as e:
         log.error(traceback.format_exc())
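`make_images()` now decodes the base64 init image and mask into `PIL.Image` objects up front, so everything downstream works with images instead of strings. `image_utils.base64_str_to_img` is not shown in this diff; a minimal equivalent might look like:

```python
import base64
import io

from PIL import Image

def base64_str_to_img(img_str: str) -> Image.Image:
    # Tolerate an optional data-URL prefix such as "data:image/png;base64,"
    if ',' in img_str:
        img_str = img_str.split(',', 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(img_str)))
```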
@@ -86,66 +87,76 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s
         }))
         raise e
 
-def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    args = req_to_args(req)
+def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    metadata = {**req.dict(), **task_data.dict()} # save_images() reads output_format/output_quality, which live on TaskData
+    del metadata['init_image']
+    del metadata['init_image_mask']
+    log.info(metadata)
 
-    images, user_stopped = generate_images(args, data_queue, task_temp_images, step_callback, req.stream_image_progress)
-    images = apply_color_correction(args, images, user_stopped)
-    images = apply_filters(args, images, user_stopped, req.show_only_filtered_image)
+    images, user_stopped = generate_images(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
+    images = apply_color_correction(req, images, user_stopped)
+    images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image)
 
-    if req.save_to_disk_path is not None:
-        out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
-        save_images(images, out_path, metadata=req.json(), show_only_filtered_image=req.show_only_filtered_image)
+    if task_data.save_to_disk_path is not None:
+        out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
+        save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image)
 
-    res = Response(req, images=construct_response(req, images))
+    res = Response(req, task_data, images=construct_response(task_data, images))
     res = res.json()
     data_queue.put(json.dumps(res))
     log.info('Task completed')
 
     return res
 
-def generate_images(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+def resolve_model_paths(task_data: TaskData):
+    task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
+    task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae')
+    task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
+
+    if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
+    if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'realesrgan')
+
+def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
     thread_data.temp_images.clear()
 
-    image_generator.on_image_step = make_step_callback(args, data_queue, task_temp_images, step_callback, stream_image_progress)
+    image_generator.on_image_step = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
 
     try:
-        images = image_generator.make_images(context=thread_data, args=args)
+        images = image_generator.make_images(context=thread_data, req=req)
         user_stopped = False
     except UserInitiatedStop:
         images = []
         user_stopped = True
-        if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
-            return images
-        for i in range(args['num_outputs']):
-            images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
-
-        del thread_data.partial_x_samples
+        if thread_data.partial_x_samples is not None:
+            for i in range(req.num_outputs):
+                images.append(image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0)))
+
+            thread_data.partial_x_samples = None
     finally:
         model_loader.gc(thread_data)
 
-    images = [(image, args['seed'] + i, False) for i, image in enumerate(images)]
+    images = [(image, req.seed + i, False) for i, image in enumerate(images)]
 
     return images, user_stopped
 
-def apply_color_correction(args: dict, images: list, user_stopped):
-    if user_stopped or args['init_image'] is None or not args['apply_color_correction']:
+def apply_color_correction(req: GenerateImageRequest, images: list, user_stopped):
+    if user_stopped or req.init_image is None or not req.apply_color_correction:
         return images
 
     for i, img_info in enumerate(images):
         img, seed, filtered = img_info
-        img = image_utils.apply_color_correction(orig_image=args['init_image'], image_to_correct=img)
+        img = image_utils.apply_color_correction(orig_image=req.init_image, image_to_correct=img)
         images[i] = (img, seed, filtered)
 
     return images
 
-def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_image):
-    if user_stopped or (args['use_face_correction'] is None and args['use_upscale'] is None):
+def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_filtered_image):
+    if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
         return images
 
     filters = []
-    if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan)
-    if 'realesrgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_realesrgan)
+    if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_gfpgan)
+    if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters.append(image_filters.apply_realesrgan)
 
     filtered_images = []
     for img, seed, _ in images:
@@ -188,37 +199,29 @@ def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filte
             img_path += '.' + metadata['output_format']
             img.save(img_path, quality=metadata['output_quality'])
 
-def construct_response(req: Request, images: list):
+def construct_response(task_data: TaskData, images: list):
     return [
         ResponseImage(
-            data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality),
+            data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality),
             seed=seed
         ) for img, seed, _ in images
     ]
 
-def req_to_args(req: Request):
-    args = req.json()
-
-    args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
-    args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None
-
-    return args
-
-def make_step_callback(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
-    n_steps = args['num_inference_steps'] if args['init_image'] is None else int(args['num_inference_steps'] * args['prompt_strength'])
+def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+    n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
     last_callback_time = -1
 
     def update_temp_img(x_samples, task_temp_images: list):
         partial_images = []
-        for i in range(args['num_outputs']):
+        for i in range(req.num_outputs):
             img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
             buf = image_utils.img_to_buffer(img, output_format='JPEG')
 
             del img
 
-            thread_data.temp_images[f"{args['request_id']}/{i}"] = buf
+            thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf
             task_temp_images[i] = buf
-            partial_images.append({'path': f"/image/tmp/{args['request_id']}/{i}"})
+            partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
         return partial_images
 
     def on_image_step(x_samples, i):
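`make_step_callback()` keeps a `last_callback_time`, presumably to rate-limit the partial-image updates pushed onto `data_queue`. The general pattern, independent of this codebase:

```python
import time

def throttled(callback, min_interval: float = 0.5):
    # Forward at most one progress update per min_interval seconds.
    last_call = [-1.0]

    def wrapper(*args, **kwargs):
        now = time.time()
        if now - last_call[0] >= min_interval:
            last_call[0] = now
            callback(*args, **kwargs)

    return wrapper
```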
diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py
index 9d1d26f6..3ec1b99d 100644
--- a/ui/sd_internal/task_manager.py
+++ b/ui/sd_internal/task_manager.py
@@ -14,8 +14,8 @@ import torch
 import queue, threading, time, weakref
 from typing import Any, Hashable
 
-from pydantic import BaseModel
-from sd_internal import Request, device_manager
+from sd_internal import TaskData, device_manager
+from modules.types import GenerateImageRequest
 
 log = logging.getLogger()
 
@@ -39,9 +39,10 @@ class ServerStates:
     class Unavailable(Symbol): pass
 
 class RenderTask(): # Task with output queue and completion lock.
-    def __init__(self, req: Request):
-        req.request_id = id(self)
-        self.request: Request = req # Initial Request
+    def __init__(self, req: GenerateImageRequest, task_data: TaskData):
+        task_data.request_id = id(self) # request_id now lives on TaskData, which is what the temp-image paths key on
+        self.render_request: GenerateImageRequest = req # Initial Request
+        self.task_data: TaskData = task_data
         self.response: Any = None # Copy of the last reponse
         self.render_device = None # Select the task affinity. (Not used to change active devices).
-        self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
+        self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
@@ -72,55 +73,6 @@ class RenderTask(): # Task with output queue and completion lock.
     def is_pending(self):
         return bool(not self.response and not self.error)
 
-# defaults from https://huggingface.co/blog/stable_diffusion
-class ImageRequest(BaseModel):
-    session_id: str = "session"
-    prompt: str = ""
-    negative_prompt: str = ""
-    init_image: str = None # base64
-    mask: str = None # base64
-    apply_color_correction: bool = False
-    num_outputs: int = 1
-    num_inference_steps: int = 50
-    guidance_scale: float = 7.5
-    width: int = 512
-    height: int = 512
-    seed: int = 42
-    prompt_strength: float = 0.8
-    sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
-    # allow_nsfw: bool = False
-    save_to_disk_path: str = None
-    turbo: bool = True
-    use_cpu: bool = False ##TODO Remove after UI and plugins transition.
-    render_device: str = None # Select the task affinity. (Not used to change active devices).
-    use_full_precision: bool = False
-    use_face_correction: str = None # or "GFPGANv1.3"
-    use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
-    use_stable_diffusion_model: str = "sd-v1-4"
-    use_vae_model: str = None
-    use_hypernetwork_model: str = None
-    hypernetwork_strength: float = None
-    show_only_filtered_image: bool = False
-    output_format: str = "jpeg" # or "png"
-    output_quality: int = 75
-
-    stream_progress_updates: bool = False
-    stream_image_progress: bool = False
-
-class FilterRequest(BaseModel):
-    session_id: str = "session"
-    model: str = None
-    name: str = ""
-    init_image: str = None # base64
-    width: int = 512
-    height: int = 512
-    save_to_disk_path: str = None
-    turbo: bool = True
-    render_device: str = None
-    use_full_precision: bool = False
-    output_format: str = "jpeg" # or "png"
-    output_quality: int = 75
-
 # Temporary cache to allow to query tasks results for a short time after they are completed.
 class DataCache():
     def __init__(self):
@@ -311,7 +263,7 @@ def thread_render(device):
             task.response = {"status": 'failed', "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
-        log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
+        log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
         if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
         try:
             def step_callback():
@@ -322,16 +274,16 @@ def thread_render(device):
                 if isinstance(current_state_error, StopAsyncIteration):
                     task.error = current_state_error
                     current_state_error = None
-                    log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
+                    log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
 
             current_state = ServerStates.LoadingModel
-            runtime2.reload_models_if_necessary(task.request)
+            runtime2.reload_models_if_necessary(task.task_data)
 
             current_state = ServerStates.Rendering
-            task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback)
+            task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
             # Before looping back to the generator, mark cache as still alive.
             task_cache.keep(id(task), TASK_TTL)
-            session_cache.keep(task.request.session_id, TASK_TTL)
+            session_cache.keep(task.task_data.session_id, TASK_TTL)
         except Exception as e:
             task.error = e
             log.error(traceback.format_exc())
@@ -340,13 +292,13 @@ def thread_render(device):
             # Task completed
             task.lock.release()
         task_cache.keep(id(task), TASK_TTL)
-        session_cache.keep(task.request.session_id, TASK_TTL)
+        session_cache.keep(task.task_data.session_id, TASK_TTL)
         if isinstance(task.error, StopAsyncIteration):
-            log.info(f'Session {task.request.session_id} task {id(task)} cancelled!')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
         elif task.error is not None:
-            log.info(f'Session {task.request.session_id} task {id(task)} failed!')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
         else:
-            log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
         current_state = ServerStates.Online
 
 def get_cached_task(task_id:str, update_ttl:bool=False):
@@ -509,53 +461,18 @@ def shutdown_event(): # Signal render thread to close on shutdown
     global current_state_error
     current_state_error = SystemExit('Application shutting down.')
 
-def render(req : ImageRequest):
+def render(render_req: GenerateImageRequest, task_data: TaskData):
     current_thread_count = is_alive()
     if current_thread_count <= 0:  # Render thread is dead
         raise ChildProcessError('Rendering thread has died.')
 
     # Alive, check if task in cache
-    session = get_cached_session(req.session_id, update_ttl=True)
+    session = get_cached_session(task_data.session_id, update_ttl=True)
     pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
     if current_thread_count < len(pending_tasks):
-        raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
+        raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
 
-    r = Request()
-    r.session_id = req.session_id
-    r.prompt = req.prompt
-    r.negative_prompt = req.negative_prompt
-    r.init_image = req.init_image
-    r.mask = req.mask
-    r.apply_color_correction = req.apply_color_correction
-    r.num_outputs = req.num_outputs
-    r.num_inference_steps = req.num_inference_steps
-    r.guidance_scale = req.guidance_scale
-    r.width = req.width
-    r.height = req.height
-    r.seed = req.seed
-    r.prompt_strength = req.prompt_strength
-    r.sampler = req.sampler
-    # r.allow_nsfw = req.allow_nsfw
-    r.turbo = req.turbo
-    r.use_full_precision = req.use_full_precision
-    r.save_to_disk_path = req.save_to_disk_path
-    r.use_upscale: str = req.use_upscale
-    r.use_face_correction = req.use_face_correction
-    r.use_stable_diffusion_model = req.use_stable_diffusion_model
-    r.use_vae_model = req.use_vae_model
-    r.use_hypernetwork_model = req.use_hypernetwork_model
-    r.hypernetwork_strength = req.hypernetwork_strength
-    r.show_only_filtered_image = req.show_only_filtered_image
-    r.output_format = req.output_format
-    r.output_quality = req.output_quality
-
-    r.stream_progress_updates = True # the underlying implementation only supports streaming
-    r.stream_image_progress = req.stream_image_progress
-
-    if not req.stream_progress_updates:
-        r.stream_image_progress = False
-
-    new_task = RenderTask(r)
+    new_task = RenderTask(render_req, task_data)
     if session.put(new_task, TASK_TTL):
         # Use twice the normal timeout for adding user requests.
         # Tries to force session.put to fail before tasks_queue.put would.
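With the monolithic `ImageRequest` gone, `render()` takes the two models directly. Because pydantic v1 ignores unknown keys by default, one flat payload can hydrate both models, and stale fields such as the retired `use_full_precision` are dropped silently instead of rejected. A sketch (assuming `GenerateImageRequest` declares the usual `prompt`/`seed` fields; it lives outside this diff):

```python
from modules.types import GenerateImageRequest
from sd_internal import TaskData, task_manager

payload = {'prompt': 'a photo', 'seed': 42, 'session_id': 's1', 'use_full_precision': True}

# One dict feeds both models; each picks out only the fields it declares.
render_req = GenerateImageRequest.parse_obj(payload)
task_data = TaskData.parse_obj(payload)
new_task = task_manager.render(render_req, task_data)
```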
diff --git a/ui/server.py b/ui/server.py
index 258c433d..42bdb0f4 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -14,6 +14,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
 from pydantic import BaseModel
 
 from sd_internal import app, model_manager, task_manager
+from sd_internal import TaskData
+from modules.types import GenerateImageRequest
 
 log = logging.getLogger()
 
@@ -22,8 +24,6 @@ log.info(f'started at {datetime.datetime.now():%x %X}')
 
 server_api = FastAPI()
 
-# don't show access log entries for URLs that start with the given prefix
-ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
-
 NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
 
 class NoCacheStaticFiles(StaticFiles):
@@ -43,17 +43,6 @@ class SetAppConfigRequest(BaseModel):
     listen_port: int = None
     test_sd2: bool = None
 
-class LogSuppressFilter(logging.Filter):
-    def filter(self, record: logging.LogRecord) -> bool:
-        path = record.getMessage()
-        for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
-            if path.find(prefix) != -1:
-                return False
-        return True
-
-# don't log certain requests
-logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
-
 server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
 
 for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
@@ -137,20 +126,18 @@ def ping(session_id:str=None):
     return JSONResponse(response, headers=NOCACHE_HEADERS)
 
 @server_api.post('/render')
-def render(req : task_manager.ImageRequest):
+def render(req: dict):
     try:
-        app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
+        # separate out the request data into rendering and task-specific data
+        render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
+        task_data: TaskData = TaskData.parse_obj(req)
 
-        # resolve the model paths to use
-        req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
-        req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
-        req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
+        render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
 
-        if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
-        if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan')
+        app.save_model_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model)
 
         # enqueue the task
-        new_task = task_manager.render(req)
+        new_task = task_manager.render(render_req, task_data)
         response = {
             'status': str(task_manager.current_state),
             'queue': len(task_manager.tasks_queue),
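End to end, a client now posts a single flat JSON object to `/render`; the endpoint splits it into the two models and copies the legacy `mask` key onto `init_image_mask`. A hypothetical client call (the default port 9000 and the rendering field names are assumptions):

```python
import requests

payload = {
    # rendering fields (GenerateImageRequest)
    'prompt': 'an astronaut riding a horse',
    'seed': 42,
    'width': 512,
    'height': 512,
    # task fields (TaskData)
    'session_id': 'session',
    'use_stable_diffusion_model': 'sd-v1-4',
    'output_format': 'jpeg',
    # 'mask': '<base64 PNG>',   # legacy key, remapped to init_image_mask server-side
}

res = requests.post('http://localhost:9000/render', json=payload)
print(res.json()['status'])
```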