Rename runtime2.py to renderer.py; Will remove the old runtime soon

This commit is contained in:
cmdr2 2022-12-11 20:21:25 +05:30
parent 096556d8c9
commit 1a5b6ef260
2 changed files with 33 additions and 33 deletions

View File

@ -15,7 +15,7 @@ from modules.types import Context, GenerateImageRequest
log = logging.getLogger()
thread_data = Context()
context = Context() # thread-local
'''
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
'''
@ -24,13 +24,13 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
def init(device):
'''
Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
'''
thread_data.stop_processing = False
thread_data.temp_images = {}
thread_data.partial_x_samples = None
context.stop_processing = False
context.temp_images = {}
context.partial_x_samples = None
device_manager.device_init(thread_data, device)
device_manager.device_init(context, device)
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
@ -61,21 +61,21 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
log.info(req.to_metadata())
thread_data.temp_images.clear()
context.temp_images.clear()
image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
try:
images = image_generator.make_images(context=thread_data, req=req)
images = image_generator.make_images(context=context, req=req)
user_stopped = False
except UserInitiatedStop:
images = []
user_stopped = True
if thread_data.partial_x_samples is not None:
images = image_utils.latent_samples_to_images(thread_data, thread_data.partial_x_samples)
thread_data.partial_x_samples = None
if context.partial_x_samples is not None:
images = image_utils.latent_samples_to_images(context, context.partial_x_samples)
context.partial_x_samples = None
finally:
model_loader.gc(thread_data)
model_loader.gc(context)
images = [(image, req.seed + i, False) for i, image in enumerate(images)]
@ -92,7 +92,7 @@ def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_fil
filtered_images = []
for img, seed, _ in images:
for filter_fn in filters:
img = filter_fn(thread_data, img)
img = filter_fn(context, img)
filtered_images.append((img, seed, True))
@ -145,12 +145,12 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
def update_temp_img(x_samples, task_temp_images: list):
partial_images = []
for i in range(req.num_outputs):
img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
img = image_utils.latent_to_img(context, x_samples[i].unsqueeze(0))
buf = image_utils.img_to_buffer(img, output_format='JPEG')
del img
thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf
context.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
return partial_images
@ -158,7 +158,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
def on_image_step(x_samples, i):
nonlocal last_callback_time
thread_data.partial_x_samples = x_samples
context.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()
@ -171,7 +171,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
step_callback()
if thread_data.stop_processing:
if context.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return on_image_step

View File

@ -186,9 +186,9 @@ class SessionState():
return True
def thread_get_next_task():
from sd_internal import runtime2
from sd_internal import renderer
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
log.warn(f'Render thread on device: {runtime2.thread_data.device} failed to acquire manager lock.')
log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
return None
if len(tasks_queue) <= 0:
manager_lock.release()
@ -196,7 +196,7 @@ def thread_get_next_task():
task = None
try: # Select a render task.
for queued_task in tasks_queue:
if queued_task.render_device and runtime2.thread_data.device != queued_task.render_device:
if queued_task.render_device and renderer.context.device != queued_task.render_device:
# Is asking for a specific render device.
if is_alive(queued_task.render_device) > 0:
continue # requested device alive, skip current one.
@ -205,7 +205,7 @@ def thread_get_next_task():
queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
task = queued_task
break
if not queued_task.render_device and runtime2.thread_data.device == 'cpu' and is_alive() > 1:
if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
# not asking for any specific devices, cpu want to grab task but other render devices are alive.
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
task = queued_task
@ -219,9 +219,9 @@ def thread_get_next_task():
def thread_render(device):
global current_state, current_state_error
from sd_internal import runtime2, model_manager
from sd_internal import renderer, model_manager
try:
runtime2.init(device)
renderer.init(device)
except Exception as e:
log.error(traceback.format_exc())
weak_thread_data[threading.current_thread()] = {
@ -230,20 +230,20 @@ def thread_render(device):
return
weak_thread_data[threading.current_thread()] = {
'device': runtime2.thread_data.device,
'device_name': runtime2.thread_data.device_name,
'device': renderer.context.device,
'device_name': renderer.context.device_name,
'alive': True
}
model_manager.load_default_models(runtime2.thread_data)
model_manager.load_default_models(renderer.context)
current_state = ServerStates.Online
while True:
session_cache.clean()
task_cache.clean()
if not weak_thread_data[threading.current_thread()]['alive']:
log.info(f'Shutting down thread for device {runtime2.thread_data.device}')
model_manager.unload_all(runtime2.thread_data)
log.info(f'Shutting down thread for device {renderer.context.device}')
model_manager.unload_all(renderer.context)
return
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
@ -263,14 +263,14 @@ def thread_render(device):
task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
continue
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
def step_callback():
global current_state_error
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime2.thread_data.stop_processing = True
renderer.context.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
@ -278,10 +278,10 @@ def thread_render(device):
current_state = ServerStates.LoadingModel
model_manager.resolve_model_paths(task.task_data)
model_manager.reload_models_if_necessary(runtime2.thread_data, task.task_data)
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
current_state = ServerStates.Rendering
task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL)
@ -299,7 +299,7 @@ def thread_render(device):
elif task.error is not None:
log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
else:
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
current_state = ServerStates.Online
def get_cached_task(task_id:str, update_ttl:bool=False):