easydiffusion/ui/sd_internal/task_manager.py

367 lines
15 KiB
Python
Raw Normal View History

import contextlib
import json
import queue, threading, time
import traceback
from typing import Any, Generator, Hashable, Optional, Union

from pydantic import BaseModel

from sd_internal import Request, Response

TASK_TTL = 15 * 60 # Discard last session's task timeout
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' # Message suffix; callers prepend their own name when a lock acquire times out.
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
class SymbolClass(type):
    """Metaclass that makes a class print as its bare name.

    ``repr()`` gives the qualified name (including enclosing classes) and
    ``str()`` gives the short name, instead of the default ``<class '...'>``.
    """

    def __repr__(cls):
        qualified_name = cls.__qualname__
        return qualified_name

    def __str__(cls):
        short_name = cls.__name__
        return short_name
class Symbol(metaclass=SymbolClass):
    """Base for marker classes that should print as their own name."""
class ServerStates:
    """Namespace of marker classes used as values of ``current_state``."""

    class Init(Symbol):
        """Initial value of ``current_state`` at import time."""

    class LoadingModel(Symbol):
        """A model checkpoint is being loaded."""

    class Online(Symbol):
        """A render thread is up and not currently processing a task."""

    class Rendering(Symbol):
        """A task is being processed."""

    class Unavailable(Symbol):
        """Model loading failed or the render thread is shutting down."""
class RenderTask(): # Task with output queue and completion lock.
    """A queued render job together with everything produced while it runs.

    Attributes are written by the render thread and read by whoever holds a
    reference to the task (for example through the task cache).
    """

    def __init__(self, req: "Request"):
        """Create an unstarted task for ``req``.

        ``temp_images`` is pre-sized to one slot per requested output, or two
        slots per output unless ``req.show_only_filtered_image`` is set.
        """
        self.request: "Request" = req # Initial Request
        self.response: Any = None # Copy of the last response
        self.temp_images: list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
        self.error: Optional[Exception] = None
        self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
        self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments

    async def read_buffer_generator(self):
        """Yield every segment currently waiting in ``buffer_queue``, then stop.

        The generator never blocks: it ends as soon as the queue is drained.
        Each item is marked done on the queue before it is yielded.
        """
        while True:
            try:
                # EAFP: a separate empty() check could race with another reader.
                segment = self.buffer_queue.get(block=False)
            except queue.Empty:
                # Drained. End the stream instead of yielding None, which a
                # consumer expecting string segments could not handle.
                return
            self.buffer_queue.task_done()
            yield segment
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
    """Payload describing one image generation request.

    Fields that default to ``None`` are declared ``Optional`` so that an
    explicit null is accepted for them as well as an omitted value.
    """
    session_id: str = "session"
    prompt: str = ""
    negative_prompt: str = ""
    init_image: Optional[str] = None # base64
    mask: Optional[str] = None # base64
    num_outputs: int = 1
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    width: int = 512
    height: int = 512
    seed: int = 42
    prompt_strength: float = 0.8
    sampler: Optional[str] = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
    # allow_nsfw: bool = False
    save_to_disk_path: Optional[str] = None
    turbo: bool = True
    use_cpu: bool = False
    use_full_precision: bool = False
    use_face_correction: Optional[str] = None # or "GFPGANv1.3"
    use_upscale: Optional[str] = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
    use_stable_diffusion_model: str = "sd-v1-4"
    show_only_filtered_image: bool = False
    output_format: str = "jpeg" # or "png"

    stream_progress_updates: bool = False
    stream_image_progress: bool = False

class FilterRequest(BaseModel):
    """Payload describing a request to run a filter model on an existing image.

    Fields that default to ``None`` are declared ``Optional`` so that an
    explicit null is accepted for them as well as an omitted value.
    """
    session_id: str = "session"
    model: Optional[str] = None
    name: str = ""
    init_image: Optional[str] = None # base64
    width: int = 512
    height: int = 512
    save_to_disk_path: Optional[str] = None
    turbo: bool = True
    use_cpu: bool = False
    use_full_precision: bool = False
    output_format: str = "jpeg" # or "png"
# Temporary cache to allow to query tasks results for a short time after they are completed.
class TaskCache():
    """Lock-protected key/value store whose entries expire after a TTL.

    Each entry is stored as an ``(expiry_timestamp, value)`` tuple, where the
    timestamp is whole seconds since the epoch. Every public method takes the
    internal lock with a ``LOCK_TIMEOUT`` second timeout and raises
    ``Exception`` if the lock cannot be obtained in time.
    """

    def __init__(self):
        self._base = dict()
        self._lock: threading.Lock = threading.Lock()

    @contextlib.contextmanager
    def _locked(self, caller: str) -> Generator[None, None, None]:
        """Hold the cache lock for the duration of a ``with`` block.

        Raises:
            Exception: ``caller + ERR_LOCK_FAILED`` when the lock is not
                acquired within ``LOCK_TIMEOUT`` seconds.
        """
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
            raise Exception(caller + ERR_LOCK_FAILED)
        try:
            yield
        finally:
            self._lock.release()

    def _get_ttl_time(self, ttl: int) -> int:
        """Return the expiry timestamp for an entry that lives ``ttl`` seconds from now."""
        return int(time.time()) + ttl

    def _is_expired(self, timestamp: int) -> bool:
        """Return True once the current time has reached ``timestamp``."""
        return int(time.time()) >= timestamp

    def clean(self) -> None:
        """Remove every expired entry, printing one line per removed key."""
        with self._locked('TaskCache.clean'):
            # Collect first: the dict must not change size while it is iterated.
            to_delete = [key for key, (ttl, _) in self._base.items() if self._is_expired(ttl)]
            for key in to_delete:
                del self._base[key]
                print(f'Session {key} expired. Data removed.')

    def clear(self) -> None:
        """Remove all entries."""
        with self._locked('TaskCache.clear'):
            self._base.clear()

    def delete(self, key: Hashable) -> bool:
        """Remove ``key``; return True if it was present, False otherwise."""
        with self._locked('TaskCache.delete'):
            if key not in self._base:
                return False
            del self._base[key]
            return True

    def keep(self, key: Hashable, ttl: int) -> bool:
        """Restart the TTL of an existing entry.

        Returns:
            True if ``key`` was present and its expiry was reset to ``ttl``
            seconds from now, False if there is no such entry.
        """
        with self._locked('TaskCache.keep'):
            if key in self._base:
                _, value = self._base.get(key)
                self._base[key] = (self._get_ttl_time(ttl), value)
                return True
            return False

    def put(self, key: Hashable, value: Any, ttl: int) -> bool:
        """Store ``value`` under ``key`` for ``ttl`` seconds, replacing any previous entry.

        Returns:
            True on success. If storing raises, the error and traceback are
            printed and False is returned.
        """
        with self._locked('TaskCache.put'):
            try:
                self._base[key] = (
                    self._get_ttl_time(ttl), value
                )
            except Exception as e:
                print(str(e))
                print(traceback.format_exc())
                return False
            return True

    def tryGet(self, key: Hashable) -> Any:
        """Return the value stored under ``key``, or None if missing or expired.

        An expired entry is deleted on access and a message is printed.
        """
        with self._locked('TaskCache.tryGet'):
            ttl, value = self._base.get(key, (None, None))
            if ttl is not None and self._is_expired(ttl):
                print(f'Session {key} expired. Discarding data.')
                del self._base[key]
                return None
            return value

# Module-wide state shared between the render threads and the request handlers.
manager_lock = threading.Lock() # Held while reading or modifying render_threads and tasks_queue.
render_threads = [] # Threads appended by start_render_thread().
current_state = ServerStates.Init
current_state_error: Optional[Exception] = None # Last load error, or the signal set by shutdown_event().
current_model_path = None # Model successfully loaded most recently, if any.
tasks_queue = [] # RenderTask objects waiting to be picked up by a render thread.
task_cache = TaskCache() # Tasks stored under their session id by render().
default_model_to_load = None # Used by preload_model() when no path is given.
def preload_model(file_path=None):
    """Load a model checkpoint and record the outcome in the module state.

    Args:
        file_path: Model to load. ``None`` selects ``default_model_to_load``.

    Does nothing when the requested model equals ``current_model_path``.
    Otherwise ``current_state`` is set to ``LoadingModel`` while loading, then
    to ``Online`` on success. On failure the traceback is printed, the
    exception is stored in ``current_state_error``, ``current_model_path`` is
    cleared and ``current_state`` becomes ``Unavailable``; nothing is raised.
    """
    global current_state, current_state_error, current_model_path
    if file_path is None:
        file_path = default_model_to_load
    if file_path == current_model_path:
        return
    current_state = ServerStates.LoadingModel
    try:
        from . import runtime
        runtime.thread_data.ckpt_file = file_path
        runtime.load_model_ckpt()
        current_model_path = file_path
        current_state_error = None
        current_state = ServerStates.Online
    except Exception as e:
        current_model_path = None
        current_state_error = e
        current_state = ServerStates.Unavailable
        print(traceback.format_exc())

def thread_render(device):
    """Render loop for one device; runs until a shutdown signal is seen.

    Args:
        device: Passed to ``runtime.device_init``. If initialisation raises,
            the traceback is printed and the thread ends immediately.

    Each iteration purges expired cache entries, stops if
    ``current_state_error`` is a ``SystemExit``, and otherwise takes the first
    queued task that is compatible with this thread's device. The task's
    results are read from ``runtime.mk_img`` and mirrored into the task
    (``response``, ``temp_images`` and, when streaming, ``buffer_queue``).
    A ``StopAsyncIteration`` in ``current_state_error`` or ``task.error`` is
    treated as a cancel request.

    Raises:
        Exception: if a task taken from the queue is already locked.
    """
    global current_state, current_state_error, current_model_path
    from . import runtime
    try:
        runtime.device_init(device)
    except Exception:
        print(traceback.format_exc())
        return
    preload_model()
    # NOTE(review): this overwrites the Unavailable state that preload_model()
    # sets on failure; current_state_error is still set in that case.
    current_state = ServerStates.Online
    while True:
        task_cache.clean()
        if isinstance(current_state_error, SystemExit):
            current_state = ServerStates.Unavailable
            return
        task = None
        if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
            print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.')
            time.sleep(1)
            continue
        try: # Select a render task.
            for queued_task in tasks_queue:
                if queued_task.request.use_cpu and runtime.thread_data.device != 'cpu':
                    continue # Cuda Tasks
                if not queued_task.request.use_cpu and runtime.thread_data.device == 'cpu':
                    continue # CPU Tasks
                if queued_task.request.use_face_correction and not runtime.is_first_cuda_device(runtime.thread_data.device):
                    continue #TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
                task = queued_task
                break
            if task is not None:
                tasks_queue.remove(task)
        finally:
            manager_lock.release()
        if task is None:
            # Nothing queued, or nothing this device may run: idle briefly.
            time.sleep(1)
            continue
        #if current_model_path != task.request.use_stable_diffusion_model:
        #    preload_model(task.request.use_stable_diffusion_model)
        if current_state_error:
            task.error = current_state_error
            continue
        print(f'Session {task.request.session_id} starting task {id(task)}')
        if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
        try:
            # Open data generator.
            res = runtime.mk_img(task.request)
            if current_model_path == task.request.use_stable_diffusion_model:
                current_state = ServerStates.Rendering
            else:
                current_state = ServerStates.LoadingModel
            # Start reading from generator.
            dataQueue = None
            if task.request.stream_progress_updates:
                dataQueue = task.buffer_queue
            for result in res:
                if current_state == ServerStates.LoadingModel:
                    # First result received, so the requested model is loaded.
                    current_state = ServerStates.Rendering
                    current_model_path = task.request.use_stable_diffusion_model
                if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
                    runtime.stop_processing = True
                    if isinstance(current_state_error, StopAsyncIteration):
                        # Move the cancel signal onto the task and clear the global one.
                        task.error = current_state_error
                        current_state_error = None
                        print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
                if dataQueue:
                    dataQueue.put(result)
                if isinstance(result, str):
                    result = json.loads(result)
                task.response = result
                if 'output' in result:
                    # Use the position in the list, not list.index(): two equal
                    # outputs would otherwise both map to the first slot.
                    for out_index, out_obj in enumerate(result['output']):
                        if 'path' in out_obj:
                            img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
                            task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
                        elif 'data' in out_obj:
                            task.temp_images[out_index] = out_obj['data']
                # Before looping back to the generator, mark cache as still alive.
                task_cache.keep(task.request.session_id, TASK_TTL)
        except Exception as e:
            # Record the failure and fall through, so the cache TTL is
            # refreshed, the outcome is logged and the state returns to Online.
            task.error = e
            print(traceback.format_exc())
        finally:
            # Task completed
            task.lock.release()
        task_cache.keep(task.request.session_id, TASK_TTL)
        if isinstance(task.error, StopAsyncIteration):
            print(f'Session {task.request.session_id} task {id(task)} cancelled!')
        elif task.error is not None:
            print(f'Session {task.request.session_id} task {id(task)} failed!')
        else:
            print(f'Session {task.request.session_id} task {id(task)} completed.')
        current_state = ServerStates.Online

def is_alive(name=None):
    """Count the render threads that are currently alive.

    Args:
        name: Optional suffix filter. When truthy, only threads whose
            ``name`` ends with it are counted.

    Returns:
        The number of matching threads for which ``is_alive()`` is true.

    Raises:
        Exception: if ``manager_lock`` is not acquired within ``LOCK_TIMEOUT``
            seconds.
    """
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
    try:
        return sum(
            1
            for rthread in render_threads
            if (not name or rthread.name.endswith(name)) and rthread.is_alive()
        )
    finally:
        manager_lock.release()

def start_render_thread(device='auto'):
    """Start a daemon thread running ``thread_render`` for ``device`` and register it.

    The thread is named ``'Runner/' + device`` and appended to
    ``render_threads`` while ``manager_lock`` is held.

    Raises:
        Exception: if ``manager_lock`` is not acquired within ``LOCK_TIMEOUT``
            seconds.
    """
    got_lock = manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT)
    if not got_lock:
        raise Exception('start_render_threads' + ERR_LOCK_FAILED)
    print('Start new Rendering Thread on device', device)
    try:
        worker = threading.Thread(
            target=thread_render,
            kwargs={'device': device},
            name='Runner/' + device,
            daemon=True,
        )
        worker.start()
        render_threads.append(worker)
    finally:
        manager_lock.release()
def shutdown_event():
    """Signal the render thread to close on shutdown.

    Stores a ``SystemExit`` in the module-level ``current_state_error``.
    """
    global current_state_error
    stop_signal = SystemExit('Application shutting down.')
    current_state_error = stop_signal
def render(req: ImageRequest):
    """Queue a render task for ``req`` and return it.

    The request is copied field by field into an internal ``Request``, wrapped
    in a ``RenderTask``, stored in ``task_cache`` under the session id and
    appended to ``tasks_queue``.

    Args:
        req: The incoming image request.

    Returns:
        The newly queued ``RenderTask``.

    Raises:
        ChildProcessError: if no render thread is alive.
        ConnectionRefusedError: if the session already has a task that has not
            started yet (no response, no error, lock not held).
        RuntimeError: if the task could not be stored in the cache, or the
            queue lock was not acquired in time.
    """
    if not is_alive(): # Render thread is dead
        raise ChildProcessError('Rendering thread has died.')
    # Alive, check if task in cache
    task = task_cache.tryGet(req.session_id)
    if task and not task.response and not task.error and not task.lock.locked():
        # Unstarted task pending, deny queueing more than one.
        raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.')

    r = Request()
    r.session_id = req.session_id
    r.prompt = req.prompt
    r.negative_prompt = req.negative_prompt
    r.init_image = req.init_image
    r.mask = req.mask
    r.num_outputs = req.num_outputs
    r.num_inference_steps = req.num_inference_steps
    r.guidance_scale = req.guidance_scale
    r.width = req.width
    r.height = req.height
    r.seed = req.seed
    r.prompt_strength = req.prompt_strength
    r.sampler = req.sampler
    # r.allow_nsfw = req.allow_nsfw
    r.turbo = req.turbo
    r.use_cpu = req.use_cpu
    r.use_full_precision = req.use_full_precision
    r.save_to_disk_path = req.save_to_disk_path
    r.use_upscale = req.use_upscale
    r.use_face_correction = req.use_face_correction
    # The render thread reads the model from the task's request, so the
    # caller's choice must be carried over as well.
    r.use_stable_diffusion_model = req.use_stable_diffusion_model
    r.show_only_filtered_image = req.show_only_filtered_image
    r.output_format = req.output_format

    r.stream_progress_updates = True # the underlying implementation only supports streaming
    r.stream_image_progress = req.stream_image_progress
    if not req.stream_progress_updates:
        r.stream_image_progress = False

    new_task = RenderTask(r)
    if not task_cache.put(r.session_id, new_task, TASK_TTL):
        raise RuntimeError('Failed to add task to cache.')
    # Use twice the normal timeout for adding user requests.
    # Tries to force task_cache.put to fail before tasks_queue.put would.
    if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
        try:
            tasks_queue.append(new_task)
            return new_task
        finally:
            manager_lock.release()
    # The task was cached but never queued. Remove it again, otherwise the
    # session would be refused as "already pending" until the entry expires.
    task_cache.delete(r.session_id)
    raise RuntimeError('render' + ERR_LOCK_FAILED)