forked from extern/easydiffusion
Rename runtime2.py to renderer.py; Will remove the old runtime soon
This commit is contained in:
parent
096556d8c9
commit
1a5b6ef260
@ -15,7 +15,7 @@ from modules.types import Context, GenerateImageRequest
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
thread_data = Context()
|
||||
context = Context() # thread-local
|
||||
'''
|
||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||
'''
|
||||
@ -24,13 +24,13 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
|
||||
|
||||
def init(device):
|
||||
'''
|
||||
Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device
|
||||
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
|
||||
'''
|
||||
thread_data.stop_processing = False
|
||||
thread_data.temp_images = {}
|
||||
thread_data.partial_x_samples = None
|
||||
context.stop_processing = False
|
||||
context.temp_images = {}
|
||||
context.partial_x_samples = None
|
||||
|
||||
device_manager.device_init(thread_data, device)
|
||||
device_manager.device_init(context, device)
|
||||
|
||||
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
try:
|
||||
@ -61,21 +61,21 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
|
||||
|
||||
def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||
log.info(req.to_metadata())
|
||||
thread_data.temp_images.clear()
|
||||
context.temp_images.clear()
|
||||
|
||||
image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
|
||||
|
||||
try:
|
||||
images = image_generator.make_images(context=thread_data, req=req)
|
||||
images = image_generator.make_images(context=context, req=req)
|
||||
user_stopped = False
|
||||
except UserInitiatedStop:
|
||||
images = []
|
||||
user_stopped = True
|
||||
if thread_data.partial_x_samples is not None:
|
||||
images = image_utils.latent_samples_to_images(thread_data, thread_data.partial_x_samples)
|
||||
thread_data.partial_x_samples = None
|
||||
if context.partial_x_samples is not None:
|
||||
images = image_utils.latent_samples_to_images(context, context.partial_x_samples)
|
||||
context.partial_x_samples = None
|
||||
finally:
|
||||
model_loader.gc(thread_data)
|
||||
model_loader.gc(context)
|
||||
|
||||
images = [(image, req.seed + i, False) for i, image in enumerate(images)]
|
||||
|
||||
@ -92,7 +92,7 @@ def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_fil
|
||||
filtered_images = []
|
||||
for img, seed, _ in images:
|
||||
for filter_fn in filters:
|
||||
img = filter_fn(thread_data, img)
|
||||
img = filter_fn(context, img)
|
||||
|
||||
filtered_images.append((img, seed, True))
|
||||
|
||||
@ -145,12 +145,12 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
|
||||
def update_temp_img(x_samples, task_temp_images: list):
|
||||
partial_images = []
|
||||
for i in range(req.num_outputs):
|
||||
img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
|
||||
img = image_utils.latent_to_img(context, x_samples[i].unsqueeze(0))
|
||||
buf = image_utils.img_to_buffer(img, output_format='JPEG')
|
||||
|
||||
del img
|
||||
|
||||
thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf
|
||||
context.temp_images[f"{task_data.request_id}/{i}"] = buf
|
||||
task_temp_images[i] = buf
|
||||
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
|
||||
return partial_images
|
||||
@ -158,7 +158,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
|
||||
def on_image_step(x_samples, i):
|
||||
nonlocal last_callback_time
|
||||
|
||||
thread_data.partial_x_samples = x_samples
|
||||
context.partial_x_samples = x_samples
|
||||
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
|
||||
last_callback_time = time.time()
|
||||
|
||||
@ -171,7 +171,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
|
||||
|
||||
step_callback()
|
||||
|
||||
if thread_data.stop_processing:
|
||||
if context.stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
|
||||
return on_image_step
|
@ -186,9 +186,9 @@ class SessionState():
|
||||
return True
|
||||
|
||||
def thread_get_next_task():
|
||||
from sd_internal import runtime2
|
||||
from sd_internal import renderer
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
log.warn(f'Render thread on device: {runtime2.thread_data.device} failed to acquire manager lock.')
|
||||
log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
|
||||
return None
|
||||
if len(tasks_queue) <= 0:
|
||||
manager_lock.release()
|
||||
@ -196,7 +196,7 @@ def thread_get_next_task():
|
||||
task = None
|
||||
try: # Select a render task.
|
||||
for queued_task in tasks_queue:
|
||||
if queued_task.render_device and runtime2.thread_data.device != queued_task.render_device:
|
||||
if queued_task.render_device and renderer.context.device != queued_task.render_device:
|
||||
# Is asking for a specific render device.
|
||||
if is_alive(queued_task.render_device) > 0:
|
||||
continue # requested device alive, skip current one.
|
||||
@ -205,7 +205,7 @@ def thread_get_next_task():
|
||||
queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
|
||||
task = queued_task
|
||||
break
|
||||
if not queued_task.render_device and runtime2.thread_data.device == 'cpu' and is_alive() > 1:
|
||||
if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
|
||||
# not asking for any specific devices, cpu want to grab task but other render devices are alive.
|
||||
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
|
||||
task = queued_task
|
||||
@ -219,9 +219,9 @@ def thread_get_next_task():
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error
|
||||
|
||||
from sd_internal import runtime2, model_manager
|
||||
from sd_internal import renderer, model_manager
|
||||
try:
|
||||
runtime2.init(device)
|
||||
renderer.init(device)
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
weak_thread_data[threading.current_thread()] = {
|
||||
@ -230,20 +230,20 @@ def thread_render(device):
|
||||
return
|
||||
|
||||
weak_thread_data[threading.current_thread()] = {
|
||||
'device': runtime2.thread_data.device,
|
||||
'device_name': runtime2.thread_data.device_name,
|
||||
'device': renderer.context.device,
|
||||
'device_name': renderer.context.device_name,
|
||||
'alive': True
|
||||
}
|
||||
|
||||
model_manager.load_default_models(runtime2.thread_data)
|
||||
model_manager.load_default_models(renderer.context)
|
||||
current_state = ServerStates.Online
|
||||
|
||||
while True:
|
||||
session_cache.clean()
|
||||
task_cache.clean()
|
||||
if not weak_thread_data[threading.current_thread()]['alive']:
|
||||
log.info(f'Shutting down thread for device {runtime2.thread_data.device}')
|
||||
model_manager.unload_all(runtime2.thread_data)
|
||||
log.info(f'Shutting down thread for device {renderer.context.device}')
|
||||
model_manager.unload_all(renderer.context)
|
||||
return
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
current_state = ServerStates.Unavailable
|
||||
@ -263,14 +263,14 @@ def thread_render(device):
|
||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
continue
|
||||
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
|
||||
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
|
||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||
try:
|
||||
def step_callback():
|
||||
global current_state_error
|
||||
|
||||
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
||||
runtime2.thread_data.stop_processing = True
|
||||
renderer.context.stop_processing = True
|
||||
if isinstance(current_state_error, StopAsyncIteration):
|
||||
task.error = current_state_error
|
||||
current_state_error = None
|
||||
@ -278,10 +278,10 @@ def thread_render(device):
|
||||
|
||||
current_state = ServerStates.LoadingModel
|
||||
model_manager.resolve_model_paths(task.task_data)
|
||||
model_manager.reload_models_if_necessary(runtime2.thread_data, task.task_data)
|
||||
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
|
||||
task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
@ -299,7 +299,7 @@ def thread_render(device):
|
||||
elif task.error is not None:
|
||||
log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
|
||||
else:
|
||||
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
|
||||
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
|
||||
current_state = ServerStates.Online
|
||||
|
||||
def get_cached_task(task_id:str, update_ttl:bool=False):
|
||||
|
Loading…
Reference in New Issue
Block a user