"""task_manager.py: manage tasks dispatching and render threads. Notes: render_threads should be the only hard reference held by the manager to the threads. Use weak_thread_data to store all other data using weak keys. This will allow for garbage collection after the thread dies. """ import json import traceback TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout import torch import queue, threading, time, weakref from typing import Any, Generator, Hashable, Optional, Union from pydantic import BaseModel from sd_internal import Request, Response, runtime, device_manager THREAD_NAME_PREFIX = 'Runtime-Render/' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths. DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init. class SymbolClass(type): # Print nicely formatted Symbol names. def __repr__(self): return self.__qualname__ def __str__(self): return self.__name__ class Symbol(metaclass=SymbolClass): pass class ServerStates: class Init(Symbol): pass class LoadingModel(Symbol): pass class Online(Symbol): pass class Rendering(Symbol): pass class Unavailable(Symbol): pass class RenderTask(): # Task with output queue and completion lock. def __init__(self, req: Request): req.request_id = id(self) self.request: Request = req # Initial Request self.response: Any = None # Copy of the last reponse self.render_device = None # Select the task affinity. (Not used to change active devices). self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2) self.error: Exception = None self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments async def read_buffer_generator(self): try: while not self.buffer_queue.empty(): res = self.buffer_queue.get(block=False) self.buffer_queue.task_done() yield res except queue.Empty as e: yield @property def status(self): if self.lock.locked(): return 'running' if isinstance(self.error, StopAsyncIteration): return 'stopped' if self.error: return 'error' if not self.buffer_queue.empty(): return 'buffer' if self.response: return 'completed' return 'pending' @property def is_pending(self): return bool(not self.response and not self.error) # defaults from https://huggingface.co/blog/stable_diffusion class ImageRequest(BaseModel): session_id: str = "session" prompt: str = "" negative_prompt: str = "" init_image: str = None # base64 mask: str = None # base64 num_outputs: int = 1 num_inference_steps: int = 50 guidance_scale: float = 7.5 width: int = 512 height: int = 512 seed: int = 42 prompt_strength: float = 0.8 sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" # allow_nsfw: bool = False save_to_disk_path: str = None turbo: bool = True use_cpu: bool = False ##TODO Remove after UI and plugins transition. render_device: str = None # Select the task affinity. (Not used to change active devices). use_full_precision: bool = False use_face_correction: str = None # or "GFPGANv1.3" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" use_stable_diffusion_model: str = "sd-v1-4" use_vae_model: str = None use_hypernetwork_model: str = None hypernetwork_strength: float = None show_only_filtered_image: bool = False output_format: str = "jpeg" # or "png" output_quality: int = 75 stream_progress_updates: bool = False stream_image_progress: bool = False class FilterRequest(BaseModel): session_id: str = "session" model: str = None name: str = "" init_image: str = None # base64 width: int = 512 height: int = 512 save_to_disk_path: str = None turbo: bool = True render_device: str = None use_full_precision: bool = False output_format: str = "jpeg" # or "png" output_quality: int = 75 # Temporary cache to allow to query tasks results for a short time after they are completed. class DataCache(): def __init__(self): self._base = dict() self._lock: threading.Lock = threading.Lock() def _get_ttl_time(self, ttl: int) -> int: return int(time.time()) + ttl def _is_expired(self, timestamp: int) -> bool: return int(time.time()) >= timestamp def clean(self) -> None: if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED) try: # Create a list of expired keys to delete to_delete = [] for key in self._base: ttl, _ = self._base[key] if self._is_expired(ttl): to_delete.append(key) # Remove Items for key in to_delete: (_, val) = self._base[key] if isinstance(val, RenderTask): print(f'RenderTask {key} expired. Data removed.') elif isinstance(val, SessionState): print(f'Session {key} expired. Data removed.') else: print(f'Key {key} expired. Data removed.') del self._base[key] finally: self._lock.release() def clear(self) -> None: if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clear' + ERR_LOCK_FAILED) try: self._base.clear() finally: self._lock.release() def delete(self, key: Hashable) -> bool: if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.delete' + ERR_LOCK_FAILED) try: if key not in self._base: return False del self._base[key] return True finally: self._lock.release() def keep(self, key: Hashable, ttl: int) -> bool: if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED) try: if key in self._base: _, value = self._base.get(key) self._base[key] = (self._get_ttl_time(ttl), value) return True return False finally: self._lock.release() def put(self, key: Hashable, value: Any, ttl: int) -> bool: if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED) try: self._base[key] = ( self._get_ttl_time(ttl), value ) except Exception as e: print(str(e)) print(traceback.format_exc()) return False else: return True finally: self._lock.release() def tryGet(self, key: Hashable) -> Any: if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED) try: ttl, value = self._base.get(key, (None, None)) if ttl is not None and self._is_expired(ttl): print(f'Session {key} expired. Discarding data.') del self._base[key] return None return value finally: self._lock.release() manager_lock = threading.RLock() render_threads = [] current_state = ServerStates.Init current_state_error:Exception = None current_model_path = None current_vae_path = None current_hypernetwork_path = None tasks_queue = [] session_cache = DataCache() task_cache = DataCache() default_model_to_load = None default_vae_to_load = None default_hypernetwork_to_load = None weak_thread_data = weakref.WeakKeyDictionary() idle_event: threading.Event = threading.Event() class SessionState(): def __init__(self, id: str): self._id = id self._tasks_ids = [] @property def id(self): return self._id @property def tasks(self): tasks = [] for task_id in self._tasks_ids: task = task_cache.tryGet(task_id) if task: tasks.append(task) return tasks def put(self, task, ttl=TASK_TTL): task_id = id(task) self._tasks_ids.append(task_id) if not task_cache.put(task_id, task, ttl): return False while len(self._tasks_ids) > len(render_threads) * 2: self._tasks_ids.pop(0) return True def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None): global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path if ckpt_file_path == None: ckpt_file_path = default_model_to_load if vae_file_path == None: vae_file_path = default_vae_to_load if hypernetwork_file_path == None: hypernetwork_file_path = default_hypernetwork_to_load if ckpt_file_path == current_model_path and vae_file_path == current_vae_path: return current_state = ServerStates.LoadingModel try: from . import runtime runtime.thread_data.hypernetwork_file = hypernetwork_file_path runtime.thread_data.ckpt_file = ckpt_file_path runtime.thread_data.vae_file = vae_file_path runtime.load_model_ckpt() runtime.load_hypernetwork() current_model_path = ckpt_file_path current_vae_path = vae_file_path current_hypernetwork_path = hypernetwork_file_path current_state_error = None current_state = ServerStates.Online except Exception as e: current_model_path = None current_vae_path = None current_state_error = e current_state = ServerStates.Unavailable print(traceback.format_exc()) def thread_get_next_task(): from . import runtime if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.') return None if len(tasks_queue) <= 0: manager_lock.release() return None task = None try: # Select a render task. for queued_task in tasks_queue: if queued_task.render_device and runtime.thread_data.device != queued_task.render_device: # Is asking for a specific render device. if is_alive(queued_task.render_device) > 0: continue # requested device alive, skip current one. else: # Requested device is not active, return error to UI. queued_task.error = Exception(queued_task.render_device + ' is not currently active.') task = queued_task break if not queued_task.render_device and runtime.thread_data.device == 'cpu' and is_alive() > 1: # not asking for any specific devices, cpu want to grab task but other render devices are alive. continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. task = queued_task break if task is not None: del tasks_queue[tasks_queue.index(task)] return task finally: manager_lock.release() def thread_render(device): global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path from . import runtime try: runtime.thread_init(device) except Exception as e: print(traceback.format_exc()) weak_thread_data[threading.current_thread()] = { 'error': e } return weak_thread_data[threading.current_thread()] = { 'device': runtime.thread_data.device, 'device_name': runtime.thread_data.device_name, 'alive': True } if runtime.thread_data.device != 'cpu' or is_alive() == 1: preload_model() current_state = ServerStates.Online while True: session_cache.clean() task_cache.clean() if not weak_thread_data[threading.current_thread()]['alive']: print(f'Shutting down thread for device {runtime.thread_data.device}') runtime.unload_models() runtime.unload_filters() return if isinstance(current_state_error, SystemExit): current_state = ServerStates.Unavailable return task = thread_get_next_task() if task is None: idle_event.clear() idle_event.wait(timeout=1) continue if task.error is not None: print(task.error) task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue if current_state_error: task.error = current_state_error task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: if runtime.is_hypernetwork_reload_necessary(task.request): runtime.reload_hypernetwork() current_hypernetwork_path = task.request.use_hypernetwork_model if runtime.is_model_reload_necessary(task.request): current_state = ServerStates.LoadingModel runtime.reload_model() current_model_path = task.request.use_stable_diffusion_model current_vae_path = task.request.use_vae_model def step_callback(): global current_state_error if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): runtime.thread_data.stop_processing = True if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') current_state = ServerStates.Rendering task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive. task_cache.keep(id(task), TASK_TTL) session_cache.keep(task.request.session_id, TASK_TTL) except Exception as e: task.error = e task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) print(traceback.format_exc()) continue finally: # Task completed task.lock.release() task_cache.keep(id(task), TASK_TTL) session_cache.keep(task.request.session_id, TASK_TTL) if isinstance(task.error, StopAsyncIteration): print(f'Session {task.request.session_id} task {id(task)} cancelled!') elif task.error is not None: print(f'Session {task.request.session_id} task {id(task)} failed!') else: print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.') current_state = ServerStates.Online def get_cached_task(task_id:str, update_ttl:bool=False): # By calling keep before tryGet, wont discard if was expired. if update_ttl and not task_cache.keep(task_id, TASK_TTL): # Failed to keep task, already gone. return None return task_cache.tryGet(task_id) def get_cached_session(session_id:str, update_ttl:bool=False): if update_ttl: session_cache.keep(session_id, TASK_TTL) session = session_cache.tryGet(session_id) if not session: session = SessionState(session_id) session_cache.put(session_id, session, TASK_TTL) return session def get_devices(): devices = { 'all': {}, 'active': {}, } def get_device_info(device): if device == 'cpu': return {'name': device_manager.get_processor_name()} mem_free, mem_total = torch.cuda.mem_get_info(device) mem_free /= float(10**9) mem_total /= float(10**9) return { 'name': torch.cuda.get_device_name(device), 'mem_free': mem_free, 'mem_total': mem_total, } # list the compatible devices gpu_count = torch.cuda.device_count() for device in range(gpu_count): device = f'cuda:{device}' if not device_manager.is_device_compatible(device): continue devices['all'].update({device: get_device_info(device)}) devices['all'].update({'cpu': get_device_info('cpu')}) # list the activated devices if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('get_devices' + ERR_LOCK_FAILED) try: for rthread in render_threads: if not rthread.is_alive(): continue weak_data = weak_thread_data.get(rthread) if not weak_data or not 'device' in weak_data or not 'device_name' in weak_data: continue device = weak_data['device'] devices['active'].update({device: get_device_info(device)}) finally: manager_lock.release() return devices def is_alive(device=None): if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED) nbr_alive = 0 try: for rthread in render_threads: if device is not None: weak_data = weak_thread_data.get(rthread) if weak_data is None or not 'device' in weak_data or weak_data['device'] is None: continue thread_device = weak_data['device'] if thread_device != device: continue if rthread.is_alive(): nbr_alive += 1 return nbr_alive finally: manager_lock.release() def start_render_thread(device): if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED) print('Start new Rendering Thread on device', device) try: rthread = threading.Thread(target=thread_render, kwargs={'device': device}) rthread.daemon = True rthread.name = THREAD_NAME_PREFIX + device rthread.start() render_threads.append(rthread) finally: manager_lock.release() timeout = DEVICE_START_TIMEOUT while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]: if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]: print(rthread, device, 'error:', weak_thread_data[rthread]['error']) return False if timeout <= 0: return False timeout -= 1 time.sleep(1) return True def stop_render_thread(device): try: device_manager.validate_device_id(device, log_prefix='stop_render_thread') except: print(traceback.format_exc()) return False if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED) print('Stopping Rendering Thread on device', device) try: thread_to_remove = None for rthread in render_threads: weak_data = weak_thread_data.get(rthread) if weak_data is None or not 'device' in weak_data or weak_data['device'] is None: continue thread_device = weak_data['device'] if thread_device == device: weak_data['alive'] = False thread_to_remove = rthread break if thread_to_remove is not None: render_threads.remove(rthread) return True finally: manager_lock.release() return False def update_render_threads(render_devices, active_devices): devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices) print('devices_to_start', devices_to_start) print('devices_to_stop', devices_to_stop) for device in devices_to_stop: if is_alive(device) <= 0: print(device, 'is not alive') continue if not stop_render_thread(device): print(device, 'could not stop render thread') for device in devices_to_start: if is_alive(device) >= 1: print(device, 'already registered.') continue if not start_render_thread(device): print(device, 'failed to start.') if is_alive() <= 0: # No running devices, probably invalid user config. raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json') print('active devices', get_devices()['active']) def shutdown_event(): # Signal render thread to close on shutdown global current_state_error current_state_error = SystemExit('Application shutting down.') def render(req : ImageRequest): current_thread_count = is_alive() if current_thread_count <= 0: # Render thread is dead raise ChildProcessError('Rendering thread has died.') # Alive, check if task in cache session = get_cached_session(req.session_id, update_ttl=True) pending_tasks = list(filter(lambda t: t.is_pending, session.tasks)) if current_thread_count < len(pending_tasks): raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.') from . import runtime r = Request() r.session_id = req.session_id r.prompt = req.prompt r.negative_prompt = req.negative_prompt r.init_image = req.init_image r.mask = req.mask r.num_outputs = req.num_outputs r.num_inference_steps = req.num_inference_steps r.guidance_scale = req.guidance_scale r.width = req.width r.height = req.height r.seed = req.seed r.prompt_strength = req.prompt_strength r.sampler = req.sampler # r.allow_nsfw = req.allow_nsfw r.turbo = req.turbo r.use_full_precision = req.use_full_precision r.save_to_disk_path = req.save_to_disk_path r.use_upscale: str = req.use_upscale r.use_face_correction = req.use_face_correction r.use_stable_diffusion_model = req.use_stable_diffusion_model r.use_vae_model = req.use_vae_model r.use_hypernetwork_model = req.use_hypernetwork_model r.hypernetwork_strength = req.hypernetwork_strength r.show_only_filtered_image = req.show_only_filtered_image r.output_format = req.output_format r.output_quality = req.output_quality r.stream_progress_updates = True # the underlying implementation only supports streaming r.stream_image_progress = req.stream_image_progress if not req.stream_progress_updates: r.stream_image_progress = False new_task = RenderTask(r) if session.put(new_task, TASK_TTL): # Use twice the normal timeout for adding user requests. # Tries to force session.put to fail before tasks_queue.put would. if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2): try: tasks_queue.append(new_task) idle_event.set() return new_task finally: manager_lock.release() raise RuntimeError('Failed to add task to cache.')