"""task_manager.py: manage tasks dispatching and render threads.

Notes:
    render_threads should be the only hard reference held by the manager to the threads.
    Use weak_thread_data to store all other data using weak keys.
    This will allow for garbage collection after the thread dies.
"""
import contextlib
import json
import logging
import queue
import threading
import time
import traceback
import weakref
from typing import Any, Hashable, Optional

import torch

from sd_internal import TaskData, device_manager
from diffusionkit.types import GenerateImageRequest

TASK_TTL = 15 * 60  # seconds, Discard last session's task timeout

log = logging.getLogger()

THREAD_NAME_PREFIX = ''
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
LOCK_TIMEOUT = 15  # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.

DEVICE_START_TIMEOUT = 60  # seconds - Maximum time to wait for a render device to init.


class SymbolClass(type):
    """Metaclass that prints nicely formatted Symbol names."""

    def __repr__(self):
        return self.__qualname__

    def __str__(self):
        return self.__name__


class Symbol(metaclass=SymbolClass):
    """Base class for marker classes that are used as named constants."""
    pass


class ServerStates:
    """Namespace of the states that `current_state` can take."""

    class Init(Symbol):
        pass

    class LoadingModel(Symbol):
        pass

    class Online(Symbol):
        pass

    class Rendering(Symbol):
        pass

    class Unavailable(Symbol):
        pass


class RenderTask():
    """Task with output queue and completion lock."""

    def __init__(self, req: GenerateImageRequest, task_data: TaskData):
        """Bind the request to its task data and create the task's lock and buffer.

        The task's `id()` is written to `task_data.request_id`.
        """
        task_data.request_id = id(self)
        self.render_request: GenerateImageRequest = req  # Initial Request
        self.task_data: TaskData = task_data
        self.response: Any = None  # Copy of the last response
        self.render_device = None  # Select the task affinity. (Not used to change active devices).
        # One slot per output, doubled unless only the filtered image is shown.
        self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
        self.error: Optional[Exception] = None
        self.lock: threading.Lock = threading.Lock()  # Locks at task start and unlocks when task is completed
        self.buffer_queue: queue.Queue = queue.Queue()  # Queue of JSON string segments

    async def read_buffer_generator(self):
        """Yield every segment currently waiting in `buffer_queue`.

        If the queue is drained by someone else between the `empty()` check and
        the non-blocking `get()`, a single `None` is yielded and iteration ends.
        """
        try:
            while not self.buffer_queue.empty():
                res = self.buffer_queue.get(block=False)
                self.buffer_queue.task_done()
                yield res
        except queue.Empty:
            yield

    @property
    def status(self):
        """Return one of 'running', 'stopped', 'error', 'buffer', 'completed' or 'pending'."""
        if self.lock.locked():
            return 'running'
        if isinstance(self.error, StopAsyncIteration):
            return 'stopped'
        if self.error:
            return 'error'
        if not self.buffer_queue.empty():
            return 'buffer'
        if self.response:
            return 'completed'
        return 'pending'

    @property
    def is_pending(self):
        """True while the task has neither a response nor an error."""
        return bool(not self.response and not self.error)


# Temporary cache to allow to query tasks results for a short time after they are completed.
class DataCache():
    """Lock-protected dict of `key -> (expiry_timestamp, value)` entries."""

    def __init__(self):
        self._base = dict()
        self._lock: threading.Lock = threading.Lock()

    @contextlib.contextmanager
    def _locked(self, caller: str):
        """Hold the cache lock for the duration of the `with` body.

        Raises:
            Exception: 'DataCache.<caller>' + ERR_LOCK_FAILED when the lock
                cannot be acquired within LOCK_TIMEOUT seconds.
        """
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
            raise Exception('DataCache.' + caller + ERR_LOCK_FAILED)
        try:
            yield
        finally:
            self._lock.release()

    def _get_ttl_time(self, ttl: int) -> int:
        """Return the absolute expiry timestamp for a TTL given in seconds."""
        return int(time.time()) + ttl

    def _is_expired(self, timestamp: int) -> bool:
        """Return True once the current time has reached `timestamp`."""
        return int(time.time()) >= timestamp

    def clean(self) -> None:
        """Delete every expired entry, logging what kind of value was removed."""
        with self._locked('clean'):
            # Collect the keys first: the dict cannot be resized while iterating it.
            to_delete = [
                key
                for key, (expires_at, _) in self._base.items()
                if self._is_expired(expires_at)
            ]
            for key in to_delete:
                _, val = self._base[key]
                if isinstance(val, RenderTask):
                    log.debug(f'RenderTask {key} expired. Data removed.')
                elif isinstance(val, SessionState):
                    log.debug(f'Session {key} expired. Data removed.')
                else:
                    log.debug(f'Key {key} expired. Data removed.')
                del self._base[key]

    def clear(self) -> None:
        """Remove all entries."""
        with self._locked('clear'):
            self._base.clear()

    def delete(self, key: Hashable) -> bool:
        """Remove `key`; return False when it was not present."""
        with self._locked('delete'):
            if key not in self._base:
                return False
            del self._base[key]
            return True

    def keep(self, key: Hashable, ttl: int) -> bool:
        """Reset the expiry of `key` to `ttl` seconds from now.

        The stored expiry is not checked, so an entry that has expired but has
        not been cleaned yet is revived. Returns False when `key` is absent.
        """
        with self._locked('keep'):
            if key in self._base:
                _, value = self._base.get(key)
                self._base[key] = (self._get_ttl_time(ttl), value)
                return True
            return False

    def put(self, key: Hashable, value: Any, ttl: int) -> bool:
        """Store `value` under `key` for `ttl` seconds.

        Returns False (after logging the traceback) if the insert raises.
        """
        with self._locked('put'):
            try:
                self._base[key] = (self._get_ttl_time(ttl), value)
            except Exception:
                log.error(traceback.format_exc())
                return False
            else:
                return True

    def tryGet(self, key: Hashable) -> Any:
        """Return the value for `key`, or None when it is absent or expired.

        An expired entry is deleted as a side effect.
        """
        with self._locked('tryGet'):
            expires_at, value = self._base.get(key, (None, None))
            if expires_at is not None and self._is_expired(expires_at):
                log.debug(f'Session {key} expired. Discarding data.')
                del self._base[key]
                return None
            return value


manager_lock = threading.RLock()
render_threads = []
current_state = ServerStates.Init
current_state_error: Optional[Exception] = None
tasks_queue = []
session_cache = DataCache()
task_cache = DataCache()
weak_thread_data = weakref.WeakKeyDictionary()
idle_event: threading.Event = threading.Event()


@contextlib.contextmanager
def _manager_locked(caller: str):
    """Hold `manager_lock` for the duration of the `with` body.

    Raises:
        Exception: `caller` + ERR_LOCK_FAILED when the lock cannot be acquired
            within LOCK_TIMEOUT seconds.
    """
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
        raise Exception(caller + ERR_LOCK_FAILED)
    try:
        yield
    finally:
        manager_lock.release()


class SessionState():
    """Per-session list of task ids; the tasks themselves live in `task_cache`."""

    def __init__(self, id: str):
        self._id = id
        self._tasks_ids = []

    @property
    def id(self):
        """Identifier this session was created with."""
        return self._id

    @property
    def tasks(self):
        """Tasks of this session that can still be fetched from `task_cache`."""
        tasks = []
        for task_id in self._tasks_ids:
            task = task_cache.tryGet(task_id)
            if task:
                tasks.append(task)
        return tasks

    def put(self, task, ttl=TASK_TTL):
        """Cache `task` under `id(task)` and register it with this session.

        Returns False, leaving the session untouched, when the cache insert fails.
        The id list is then trimmed from the oldest end to at most
        `len(render_threads) * 2` entries.
        """
        task_id = id(task)
        # Cache first: registering the id before the insert is known to have
        # succeeded would leave a stale id behind on failure.
        if not task_cache.put(task_id, task, ttl):
            return False
        self._tasks_ids.append(task_id)
        while len(self._tasks_ids) > len(render_threads) * 2:
            self._tasks_ids.pop(0)
        return True


def thread_get_next_task():
    """Pop and return the next task the calling render thread should process.

    Returns None when the manager lock cannot be acquired in time, when the
    queue is empty, or when no queued task is suitable for this thread's device.
    A task pinned to a device that is not alive is still returned, with
    `task.error` set, so the caller can report the failure.
    """
    from sd_internal import renderer
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
        log.warning(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
        return None
    try:
        if not tasks_queue:
            return None
        task = None
        # Select a render task.
        for queued_task in tasks_queue:
            if queued_task.render_device and renderer.context.device != queued_task.render_device:
                # Is asking for a specific render device.
                if is_alive(queued_task.render_device) > 0:
                    continue  # requested device alive, skip current one.
                # Requested device is not active, return error to UI.
                queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
                task = queued_task
                break
            if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
                # not asking for any specific devices, cpu want to grab task but other render devices are alive.
                continue  # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
            task = queued_task
            break
        if task is not None:
            tasks_queue.remove(task)
        return task
    finally:
        manager_lock.release()


def thread_render(device):
    """Body of a render thread: initialise `device`, then process tasks forever.

    The thread publishes its state in `weak_thread_data`: `{'error': e}` when
    renderer initialisation fails, otherwise its device, device name and an
    'alive' flag. The loop ends when that flag is cleared or when
    `current_state_error` is a SystemExit.
    """
    global current_state, current_state_error

    from sd_internal import renderer, model_manager
    try:
        renderer.init(device)
    except Exception as e:
        log.error(traceback.format_exc())
        weak_thread_data[threading.current_thread()] = {
            'error': e
        }
        return

    weak_thread_data[threading.current_thread()] = {
        'device': renderer.context.device,
        'device_name': renderer.context.device_name,
        'alive': True
    }

    current_state = ServerStates.LoadingModel
    model_manager.load_default_models(renderer.context)

    current_state = ServerStates.Online

    while True:
        session_cache.clean()
        task_cache.clean()
        if not weak_thread_data[threading.current_thread()]['alive']:
            log.info(f'Shutting down thread for device {renderer.context.device}')
            model_manager.unload_all(renderer.context)
            return
        if isinstance(current_state_error, SystemExit):
            current_state = ServerStates.Unavailable
            return
        task = thread_get_next_task()
        if task is None:
            # Nothing to do: sleep until render() signals new work, or 1 second.
            idle_event.clear()
            idle_event.wait(timeout=1)
            continue
        if task.error is not None:
            # The task was already marked as failed when it was taken from the queue.
            log.error(task.error)
            task.response = {"status": 'failed', "detail": str(task.error)}
            task.buffer_queue.put(json.dumps(task.response))
            continue
        if current_state_error:
            task.error = current_state_error
            task.response = {"status": 'failed', "detail": str(task.error)}
            task.buffer_queue.put(json.dumps(task.response))
            continue
        log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
        if not task.lock.acquire(blocking=False):
            raise Exception('Got locked task from queue.')
        try:
            def step_callback():
                # Passed to renderer.make_images; presumably invoked between
                # render steps - confirm against the renderer.
                global current_state_error

                if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
                    renderer.context.stop_processing = True
                    if isinstance(current_state_error, StopAsyncIteration):
                        # A global stop request is consumed by the running task.
                        task.error = current_state_error
                        current_state_error = None
                        log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')

            current_state = ServerStates.LoadingModel
            model_manager.resolve_model_paths(task.task_data)
            model_manager.set_vram_optimizations(renderer.context, task.task_data)
            model_manager.reload_models_if_necessary(renderer.context, task.task_data)

            current_state = ServerStates.Rendering
            task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
            # Before looping back to the generator, mark cache as still alive.
            task_cache.keep(id(task), TASK_TTL)
            session_cache.keep(task.task_data.session_id, TASK_TTL)
        except Exception as e:
            task.error = e
            task.response = {"status": 'failed', "detail": str(task.error)}
            task.buffer_queue.put(json.dumps(task.response))
            log.error(traceback.format_exc())
            # No `continue` here: a failed task must still refresh its TTL, be
            # logged as failed and return the server to the Online state below.
        finally:
            # Task completed
            task.lock.release()
        task_cache.keep(id(task), TASK_TTL)
        session_cache.keep(task.task_data.session_id, TASK_TTL)
        if isinstance(task.error, StopAsyncIteration):
            log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
        elif task.error is not None:
            log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
        else:
            log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
        current_state = ServerStates.Online


def get_cached_task(task_id: str, update_ttl: bool = False):
    """Return the cached task for `task_id`, or None when it is gone.

    NOTE(review): tasks are cached under `id(task)` (an int), so the `str`
    annotation looks inaccurate - confirm against callers.
    """
    # By calling keep before tryGet, wont discard if was expired.
    if update_ttl and not task_cache.keep(task_id, TASK_TTL):
        # Failed to keep task, already gone.
        return None
    return task_cache.tryGet(task_id)


def get_cached_session(session_id: str, update_ttl: bool = False):
    """Return the SessionState for `session_id`, creating and caching it if missing."""
    if update_ttl:
        session_cache.keep(session_id, TASK_TTL)
    session = session_cache.tryGet(session_id)
    if not session:
        session = SessionState(session_id)
        session_cache.put(session_id, session, TASK_TTL)
    return session


def get_devices():
    """Return `{'all': {...}, 'active': {...}}` mapping device ids to info dicts.

    'all' holds every compatible CUDA device plus 'cpu'; 'active' holds the
    devices of render threads that are alive and have published their device.

    Raises:
        Exception: when the manager lock cannot be acquired in time.
    """
    devices = {
        'all': {},
        'active': {},
    }

    def get_device_info(device):
        if device == 'cpu':
            return {'name': device_manager.get_processor_name()}

        mem_free, mem_total = torch.cuda.mem_get_info(device)
        # Divide by 10**9 (decimal gigabytes, given that mem_get_info reports bytes).
        mem_free /= float(10**9)
        mem_total /= float(10**9)

        return {
            'name': torch.cuda.get_device_name(device),
            'mem_free': mem_free,
            'mem_total': mem_total,
        }

    # list the compatible devices
    gpu_count = torch.cuda.device_count()
    for device in range(gpu_count):
        device = f'cuda:{device}'
        if not device_manager.is_device_compatible(device):
            continue

        devices['all'].update({device: get_device_info(device)})

    devices['all'].update({'cpu': get_device_info('cpu')})

    # list the activated devices
    with _manager_locked('get_devices'):
        for rthread in render_threads:
            if not rthread.is_alive():
                continue
            weak_data = weak_thread_data.get(rthread)
            if not weak_data or 'device' not in weak_data or 'device_name' not in weak_data:
                continue
            device = weak_data['device']
            devices['active'].update({device: get_device_info(device)})

    return devices


def is_alive(device=None):
    """Count the live render threads, optionally only those bound to `device`.

    Raises:
        Exception: when the manager lock cannot be acquired in time.
    """
    with _manager_locked('is_alive'):
        nbr_alive = 0
        for rthread in render_threads:
            if device is not None:
                weak_data = weak_thread_data.get(rthread)
                if weak_data is None or 'device' not in weak_data or weak_data['device'] is None:
                    continue
                thread_device = weak_data['device']
                if thread_device != device:
                    continue
            if rthread.is_alive():
                nbr_alive += 1
        return nbr_alive


def start_render_thread(device):
    """Start a daemon render thread for `device` and wait for it to come up.

    Polls once per second, for up to DEVICE_START_TIMEOUT seconds, until the
    thread has published its device in `weak_thread_data`.

    Returns:
        True once the device is published; False if the thread reported an
        initialisation error or the timeout ran out.

    Raises:
        Exception: when the manager lock cannot be acquired in time.
    """
    with _manager_locked('start_render_thread'):
        log.info(f'Start new Rendering Thread on device: {device}')
        rthread = threading.Thread(target=thread_render, kwargs={'device': device})
        rthread.daemon = True
        rthread.name = THREAD_NAME_PREFIX + device
        rthread.start()
        render_threads.append(rthread)

    timeout = DEVICE_START_TIMEOUT
    while not rthread.is_alive() or rthread not in weak_thread_data or 'device' not in weak_thread_data[rthread]:
        if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
            log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
            return False
        if timeout <= 0:
            return False
        timeout -= 1
        time.sleep(1)
    return True


def stop_render_thread(device):
    """Ask the render thread bound to `device` to stop and drop the reference to it.

    The thread is not joined: its 'alive' flag is cleared and it exits on its
    next loop iteration.

    Returns:
        True when a matching thread was found; False when validation of
        `device` fails or no thread is bound to it.

    Raises:
        Exception: when the manager lock cannot be acquired in time.
    """
    try:
        device_manager.validate_device_id(device, log_prefix='stop_render_thread')
    except Exception:
        log.error(traceback.format_exc())
        return False

    with _manager_locked('stop_render_thread'):
        log.info(f'Stopping Rendering Thread on device: {device}')
        thread_to_remove = None
        for rthread in render_threads:
            weak_data = weak_thread_data.get(rthread)
            if weak_data is None or 'device' not in weak_data or weak_data['device'] is None:
                continue
            thread_device = weak_data['device']
            if thread_device == device:
                weak_data['alive'] = False
                thread_to_remove = rthread
                break
        if thread_to_remove is not None:
            render_threads.remove(thread_to_remove)
            return True

    return False


def update_render_threads(render_devices, active_devices):
    """Start and stop render threads according to `device_manager.get_device_delta`.

    Raises:
        EnvironmentError: when no render thread is alive afterwards.
    """
    devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
    log.debug(f'devices_to_start: {devices_to_start}')
    log.debug(f'devices_to_stop: {devices_to_stop}')

    for device in devices_to_stop:
        if is_alive(device) <= 0:
            log.debug(f'{device} is not alive')
            continue
        if not stop_render_thread(device):
            log.warning(f'{device} could not stop render thread')

    for device in devices_to_start:
        if is_alive(device) >= 1:
            log.debug(f'{device} already registered.')
            continue
        if not start_render_thread(device):
            log.warning(f'{device} failed to start.')

    if is_alive() <= 0:  # No running devices, probably invalid user config.
        raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')

    log.debug(f"active devices: {get_devices()['active']}")


def shutdown_event():  # Signal render thread to close on shutdown
    """Set `current_state_error` to a SystemExit so render threads leave their loop."""
    global current_state_error
    current_state_error = SystemExit('Application shutting down.')


def render(render_req: GenerateImageRequest, task_data: TaskData):
    """Create a RenderTask for the request, queue it and return it.

    Raises:
        ChildProcessError: when no render thread is alive.
        ConnectionRefusedError: when the session already has more pending tasks
            than there are live render threads.
        RuntimeError: when the task cannot be cached or the manager lock cannot
            be acquired in time.
    """
    current_thread_count = is_alive()
    if current_thread_count <= 0:  # Render thread is dead
        raise ChildProcessError('Rendering thread has died.')

    # Alive, check if task in cache
    session = get_cached_session(task_data.session_id, update_ttl=True)
    pending_tasks = [t for t in session.tasks if t.is_pending]
    if current_thread_count < len(pending_tasks):
        raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')

    new_task = RenderTask(render_req, task_data)
    if session.put(new_task, TASK_TTL):
        # Use twice the normal timeout for adding user requests.
        # Tries to force session.put to fail before tasks_queue.put would.
        if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
            try:
                tasks_queue.append(new_task)
                idle_event.set()  # Wake an idle render thread.
                return new_task
            finally:
                manager_lock.release()
    raise RuntimeError('Failed to add task to cache.')