Rename runtime2.py to renderer.py; Will remove the old runtime soon

This commit is contained in:
cmdr2 2022-12-11 20:21:25 +05:30
parent 096556d8c9
commit 1a5b6ef260
2 changed files with 33 additions and 33 deletions

View File

@ -15,7 +15,7 @@ from modules.types import Context, GenerateImageRequest
log = logging.getLogger() log = logging.getLogger()
thread_data = Context() context = Context() # thread-local
''' '''
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
''' '''
@ -24,13 +24,13 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
def init(device): def init(device):
''' '''
Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device Initializes the fields that will be bound to this runtime's context, and sets the current torch device
''' '''
thread_data.stop_processing = False context.stop_processing = False
thread_data.temp_images = {} context.temp_images = {}
thread_data.partial_x_samples = None context.partial_x_samples = None
device_manager.device_init(thread_data, device) device_manager.device_init(context, device)
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
try: try:
@ -61,21 +61,21 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
log.info(req.to_metadata()) log.info(req.to_metadata())
thread_data.temp_images.clear() context.temp_images.clear()
image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress) image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
try: try:
images = image_generator.make_images(context=thread_data, req=req) images = image_generator.make_images(context=context, req=req)
user_stopped = False user_stopped = False
except UserInitiatedStop: except UserInitiatedStop:
images = [] images = []
user_stopped = True user_stopped = True
if thread_data.partial_x_samples is not None: if context.partial_x_samples is not None:
images = image_utils.latent_samples_to_images(thread_data, thread_data.partial_x_samples) images = image_utils.latent_samples_to_images(context, context.partial_x_samples)
thread_data.partial_x_samples = None context.partial_x_samples = None
finally: finally:
model_loader.gc(thread_data) model_loader.gc(context)
images = [(image, req.seed + i, False) for i, image in enumerate(images)] images = [(image, req.seed + i, False) for i, image in enumerate(images)]
@ -92,7 +92,7 @@ def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_fil
filtered_images = [] filtered_images = []
for img, seed, _ in images: for img, seed, _ in images:
for filter_fn in filters: for filter_fn in filters:
img = filter_fn(thread_data, img) img = filter_fn(context, img)
filtered_images.append((img, seed, True)) filtered_images.append((img, seed, True))
@ -145,12 +145,12 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
def update_temp_img(x_samples, task_temp_images: list): def update_temp_img(x_samples, task_temp_images: list):
partial_images = [] partial_images = []
for i in range(req.num_outputs): for i in range(req.num_outputs):
img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0)) img = image_utils.latent_to_img(context, x_samples[i].unsqueeze(0))
buf = image_utils.img_to_buffer(img, output_format='JPEG') buf = image_utils.img_to_buffer(img, output_format='JPEG')
del img del img
thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf context.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf task_temp_images[i] = buf
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"}) partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
return partial_images return partial_images
@ -158,7 +158,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
def on_image_step(x_samples, i): def on_image_step(x_samples, i):
nonlocal last_callback_time nonlocal last_callback_time
thread_data.partial_x_samples = x_samples context.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time() last_callback_time = time.time()
@ -171,7 +171,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
step_callback() step_callback()
if thread_data.stop_processing: if context.stop_processing:
raise UserInitiatedStop("User requested that we stop processing") raise UserInitiatedStop("User requested that we stop processing")
return on_image_step return on_image_step

View File

@ -186,9 +186,9 @@ class SessionState():
return True return True
def thread_get_next_task(): def thread_get_next_task():
from sd_internal import runtime2 from sd_internal import renderer
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
log.warn(f'Render thread on device: {runtime2.thread_data.device} failed to acquire manager lock.') log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
return None return None
if len(tasks_queue) <= 0: if len(tasks_queue) <= 0:
manager_lock.release() manager_lock.release()
@ -196,7 +196,7 @@ def thread_get_next_task():
task = None task = None
try: # Select a render task. try: # Select a render task.
for queued_task in tasks_queue: for queued_task in tasks_queue:
if queued_task.render_device and runtime2.thread_data.device != queued_task.render_device: if queued_task.render_device and renderer.context.device != queued_task.render_device:
# Is asking for a specific render device. # Is asking for a specific render device.
if is_alive(queued_task.render_device) > 0: if is_alive(queued_task.render_device) > 0:
continue # requested device alive, skip current one. continue # requested device alive, skip current one.
@ -205,7 +205,7 @@ def thread_get_next_task():
queued_task.error = Exception(queued_task.render_device + ' is not currently active.') queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
task = queued_task task = queued_task
break break
if not queued_task.render_device and runtime2.thread_data.device == 'cpu' and is_alive() > 1: if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
# not asking for any specific devices, cpu want to grab task but other render devices are alive. # not asking for any specific devices, cpu want to grab task but other render devices are alive.
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
task = queued_task task = queued_task
@ -219,9 +219,9 @@ def thread_get_next_task():
def thread_render(device): def thread_render(device):
global current_state, current_state_error global current_state, current_state_error
from sd_internal import runtime2, model_manager from sd_internal import renderer, model_manager
try: try:
runtime2.init(device) renderer.init(device)
except Exception as e: except Exception as e:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
weak_thread_data[threading.current_thread()] = { weak_thread_data[threading.current_thread()] = {
@ -230,20 +230,20 @@ def thread_render(device):
return return
weak_thread_data[threading.current_thread()] = { weak_thread_data[threading.current_thread()] = {
'device': runtime2.thread_data.device, 'device': renderer.context.device,
'device_name': runtime2.thread_data.device_name, 'device_name': renderer.context.device_name,
'alive': True 'alive': True
} }
model_manager.load_default_models(runtime2.thread_data) model_manager.load_default_models(renderer.context)
current_state = ServerStates.Online current_state = ServerStates.Online
while True: while True:
session_cache.clean() session_cache.clean()
task_cache.clean() task_cache.clean()
if not weak_thread_data[threading.current_thread()]['alive']: if not weak_thread_data[threading.current_thread()]['alive']:
log.info(f'Shutting down thread for device {runtime2.thread_data.device}') log.info(f'Shutting down thread for device {renderer.context.device}')
model_manager.unload_all(runtime2.thread_data) model_manager.unload_all(renderer.context)
return return
if isinstance(current_state_error, SystemExit): if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable current_state = ServerStates.Unavailable
@ -263,14 +263,14 @@ def thread_render(device):
task.response = {"status": 'failed', "detail": str(task.error)} task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response)) task.buffer_queue.put(json.dumps(task.response))
continue continue
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try: try:
def step_callback(): def step_callback():
global current_state_error global current_state_error
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime2.thread_data.stop_processing = True renderer.context.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration): if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error task.error = current_state_error
current_state_error = None current_state_error = None
@ -278,10 +278,10 @@ def thread_render(device):
current_state = ServerStates.LoadingModel current_state = ServerStates.LoadingModel
model_manager.resolve_model_paths(task.task_data) model_manager.resolve_model_paths(task.task_data)
model_manager.reload_models_if_necessary(runtime2.thread_data, task.task_data) model_manager.reload_models_if_necessary(renderer.context, task.task_data)
current_state = ServerStates.Rendering current_state = ServerStates.Rendering
task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
# Before looping back to the generator, mark cache as still alive. # Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL) task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL) session_cache.keep(task.task_data.session_id, TASK_TTL)
@ -299,7 +299,7 @@ def thread_render(device):
elif task.error is not None: elif task.error is not None:
log.info(f'Session {task.task_data.session_id} task {id(task)} failed!') log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
else: else:
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.') log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
current_state = ServerStates.Online current_state = ServerStates.Online
def get_cached_task(task_id:str, update_ttl:bool=False): def get_cached_task(task_id:str, update_ttl:bool=False):