forked from extern/easydiffusion

Rename runtime2.py to renderer.py; will remove the old runtime soon

parent 096556d8c9
commit 1a5b6ef260
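
The heart of this commit is a pair of renames: module sd_internal.runtime2 becomes sd_internal.renderer, and its module-level Context instance thread_data becomes context. For orientation, here is a minimal sketch of the thread-local pattern the docstring below describes — illustrative code only, not this repo's (the real Context comes from modules.types, and device binding happens in device_manager.device_init):

import threading

class Context(threading.local):
    # Subclassing threading.local gives each thread that touches the shared
    # `context` object its own copy of these fields; __init__ runs once per
    # thread, on first access.
    def __init__(self):
        self.device = None
        self.stop_processing = False
        self.temp_images = {}
        self.partial_x_samples = None

context = Context()

def init(device):
    # Mirrors renderer.init(): reset the per-thread fields, then bind the device.
    context.stop_processing = False
    context.temp_images = {}
    context.partial_x_samples = None
    context.device = device  # stand-in for device_manager.device_init(context, device)

def worker(device):
    init(device)
    print(threading.current_thread().name, 'uses', context.device)

for name, dev in (('render-0', 'cuda:0'), ('render-1', 'cuda:1')):
    threading.Thread(target=worker, args=(dev,), name=name).start()

Each render thread calls init(device) once and then reads the same global name, which is what lets the diff below swap thread_data for context without threading any arguments through.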
runtime2.py → renderer.py

@@ -15,7 +15,7 @@ from modules.types import Context, GenerateImageRequest
 
 log = logging.getLogger()
 
-thread_data = Context()
+context = Context() # thread-local
 '''
 runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
 '''
@@ -24,13 +24,13 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
 
 def init(device):
     '''
-    Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device
+    Initializes the fields that will be bound to this runtime's context, and sets the current torch device
     '''
-    thread_data.stop_processing = False
-    thread_data.temp_images = {}
-    thread_data.partial_x_samples = None
+    context.stop_processing = False
+    context.temp_images = {}
+    context.partial_x_samples = None
 
-    device_manager.device_init(thread_data, device)
+    device_manager.device_init(context, device)
 
 def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
@@ -61,21 +61,21 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
 
 def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
     log.info(req.to_metadata())
-    thread_data.temp_images.clear()
+    context.temp_images.clear()
 
     image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
 
     try:
-        images = image_generator.make_images(context=thread_data, req=req)
+        images = image_generator.make_images(context=context, req=req)
         user_stopped = False
     except UserInitiatedStop:
         images = []
         user_stopped = True
-        if thread_data.partial_x_samples is not None:
-            images = image_utils.latent_samples_to_images(thread_data, thread_data.partial_x_samples)
-            thread_data.partial_x_samples = None
+        if context.partial_x_samples is not None:
+            images = image_utils.latent_samples_to_images(context, context.partial_x_samples)
+            context.partial_x_samples = None
     finally:
-        model_loader.gc(thread_data)
+        model_loader.gc(context)
 
     images = [(image, req.seed + i, False) for i, image in enumerate(images)]
 
@@ -92,7 +92,7 @@ def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_fil
     filtered_images = []
     for img, seed, _ in images:
         for filter_fn in filters:
-            img = filter_fn(thread_data, img)
+            img = filter_fn(context, img)
 
         filtered_images.append((img, seed, True))
 
@@ -145,12 +145,12 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
     def update_temp_img(x_samples, task_temp_images: list):
         partial_images = []
         for i in range(req.num_outputs):
-            img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
+            img = image_utils.latent_to_img(context, x_samples[i].unsqueeze(0))
             buf = image_utils.img_to_buffer(img, output_format='JPEG')
 
             del img
 
-            thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf
+            context.temp_images[f"{task_data.request_id}/{i}"] = buf
             task_temp_images[i] = buf
             partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
         return partial_images
@@ -158,7 +158,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
     def on_image_step(x_samples, i):
         nonlocal last_callback_time
 
-        thread_data.partial_x_samples = x_samples
+        context.partial_x_samples = x_samples
         step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
         last_callback_time = time.time()
 
@@ -171,7 +171,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
 
         step_callback()
 
-        if thread_data.stop_processing:
+        if context.stop_processing:
             raise UserInitiatedStop("User requested that we stop processing")
 
     return on_image_step
task_manager.py

@@ -186,9 +186,9 @@ class SessionState():
         return True
 
 def thread_get_next_task():
-    from sd_internal import runtime2
+    from sd_internal import renderer
     if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
-        log.warn(f'Render thread on device: {runtime2.thread_data.device} failed to acquire manager lock.')
+        log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
         return None
     if len(tasks_queue) <= 0:
         manager_lock.release()
@@ -196,7 +196,7 @@ def thread_get_next_task():
     task = None
     try: # Select a render task.
         for queued_task in tasks_queue:
-            if queued_task.render_device and runtime2.thread_data.device != queued_task.render_device:
+            if queued_task.render_device and renderer.context.device != queued_task.render_device:
                 # Is asking for a specific render device.
                 if is_alive(queued_task.render_device) > 0:
                     continue # requested device alive, skip current one.
@@ -205,7 +205,7 @@ def thread_get_next_task():
                 queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
                 task = queued_task
                 break
-            if not queued_task.render_device and runtime2.thread_data.device == 'cpu' and is_alive() > 1:
+            if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
                 # not asking for any specific devices, cpu want to grab task but other render devices are alive.
                 continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
             task = queued_task
@@ -219,9 +219,9 @@ def thread_get_next_task():
 def thread_render(device):
     global current_state, current_state_error
 
-    from sd_internal import runtime2, model_manager
+    from sd_internal import renderer, model_manager
     try:
-        runtime2.init(device)
+        renderer.init(device)
     except Exception as e:
         log.error(traceback.format_exc())
         weak_thread_data[threading.current_thread()] = {
@@ -230,20 +230,20 @@ def thread_render(device):
         return
 
     weak_thread_data[threading.current_thread()] = {
-        'device': runtime2.thread_data.device,
-        'device_name': runtime2.thread_data.device_name,
+        'device': renderer.context.device,
+        'device_name': renderer.context.device_name,
         'alive': True
     }
 
-    model_manager.load_default_models(runtime2.thread_data)
+    model_manager.load_default_models(renderer.context)
     current_state = ServerStates.Online
 
     while True:
         session_cache.clean()
         task_cache.clean()
         if not weak_thread_data[threading.current_thread()]['alive']:
-            log.info(f'Shutting down thread for device {runtime2.thread_data.device}')
-            model_manager.unload_all(runtime2.thread_data)
+            log.info(f'Shutting down thread for device {renderer.context.device}')
+            model_manager.unload_all(renderer.context)
             return
         if isinstance(current_state_error, SystemExit):
             current_state = ServerStates.Unavailable
@@ -263,14 +263,14 @@ def thread_render(device):
             task.response = {"status": 'failed', "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
-        log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
+        log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
         if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
         try:
             def step_callback():
                 global current_state_error
 
                 if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
-                    runtime2.thread_data.stop_processing = True
+                    renderer.context.stop_processing = True
                     if isinstance(current_state_error, StopAsyncIteration):
                         task.error = current_state_error
                         current_state_error = None
||||||
@ -278,10 +278,10 @@ def thread_render(device):
|
|||||||
|
|
||||||
current_state = ServerStates.LoadingModel
|
current_state = ServerStates.LoadingModel
|
||||||
model_manager.resolve_model_paths(task.task_data)
|
model_manager.resolve_model_paths(task.task_data)
|
||||||
model_manager.reload_models_if_necessary(runtime2.thread_data, task.task_data)
|
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
|
||||||
|
|
||||||
current_state = ServerStates.Rendering
|
current_state = ServerStates.Rendering
|
||||||
task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
|
task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
|
||||||
# Before looping back to the generator, mark cache as still alive.
|
# Before looping back to the generator, mark cache as still alive.
|
||||||
task_cache.keep(id(task), TASK_TTL)
|
task_cache.keep(id(task), TASK_TTL)
|
||||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||||
@@ -299,7 +299,7 @@ def thread_render(device):
         elif task.error is not None:
             log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
         else:
-            log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
         current_state = ServerStates.Online
 
 def get_cached_task(task_id:str, update_ttl:bool=False):
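
A detail that spans both files above: stopping is cooperative. task_manager's step_callback sets context.stop_processing, renderer's on_image_step raises UserInitiatedStop when it sees the flag, and generate_images catches the exception and decodes context.partial_x_samples so the user still receives the partially-denoised result. A compressed, runnable sketch of that handshake, with stand-in names (Ctx, the simulated stop) rather than the repo's actual classes:

class UserInitiatedStop(Exception):
    pass

class Ctx:
    def __init__(self):
        self.stop_processing = False
        self.partial_x_samples = None

def on_image_step(ctx, x_samples):
    ctx.partial_x_samples = x_samples   # always keep the latest latents
    if ctx.stop_processing:             # flag is set externally by step_callback
        raise UserInitiatedStop('User requested that we stop processing')

def generate_images(ctx, steps=20, stop_at=5):
    try:
        for step in range(steps):
            on_image_step(ctx, x_samples=f'latents@step{step}')
            if step == stop_at:
                ctx.stop_processing = True  # simulates the user clicking Stop
        return ['finished image']
    except UserInitiatedStop:
        salvaged = ctx.partial_x_samples    # decode these instead of discarding them
        ctx.partial_x_samples = None
        return [salvaged] if salvaged is not None else []

print(generate_images(Ctx()))  # prints ['latents@step6']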