Simplify the runtime code

This commit is contained in:
cmdr2 2022-12-11 19:58:12 +05:30
parent 0aa7968503
commit 97919c7e87
3 changed files with 23 additions and 34 deletions

View File

@ -4,6 +4,8 @@ import picklescan.scanner
import rich
from sd_internal import app
from modules import model_loader
from modules.types import Context
log = logging.getLogger()
@ -30,6 +32,18 @@ def init():
make_model_folders()
getModels() # run this once, to cache the picklescan results
def load_default_models(context: Context):
# init default model paths
for model_type in KNOWN_MODEL_TYPES:
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
# load mandatory models
model_loader.load_model(context, 'stable-diffusion')
def unload_all(context: Context):
for model_type in KNOWN_MODEL_TYPES:
model_loader.unload_model(context, model_type)
def resolve_model_to_use(model_name:str=None, model_type:str=None):
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
default_models = DEFAULT_MODELS.get(model_type, [])

View File

@ -33,18 +33,6 @@ def init(device):
device_manager.device_init(thread_data, device)
def destroy():
for model_type in model_manager.KNOWN_MODEL_TYPES:
model_loader.unload_model(thread_data, model_type)
def load_default_models():
# init default model paths
for model_type in model_manager.KNOWN_MODEL_TYPES:
thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type)
# load mandatory models
model_loader.load_model(thread_data, 'stable-diffusion')
def reload_models_if_necessary(task_data: TaskData):
model_paths_in_req = (
('hypernetwork', task_data.use_hypernetwork_model),
@ -67,14 +55,6 @@ def reload_models_if_necessary(task_data: TaskData):
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
# resolve the model paths to use
resolve_model_paths(task_data)
# convert init image to PIL.Image
req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None
# generate
return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
except Exception as e:
log.error(traceback.format_exc())
@ -86,17 +66,12 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
raise e
def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
metadata = req.dict()
del metadata['init_image']
del metadata['init_image_mask']
print(metadata)
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image)
if task_data.save_to_disk_path is not None:
out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image)
save_images(images, out_path, metadata=req.to_metadata(), show_only_filtered_image=task_data.show_only_filtered_image)
res = Response(req, task_data, images=construct_response(images))
res = res.json()
@ -114,6 +89,7 @@ def resolve_model_paths(task_data: TaskData):
if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan')
def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
log.info(req.to_metadata())
thread_data.temp_images.clear()
image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
@ -125,10 +101,8 @@ def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_tem
images = []
user_stopped = True
if thread_data.partial_x_samples is not None:
for i in range(req.num_outputs):
images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
thread_data.partial_x_samples = None
images = image_utils.latent_samples_to_images(thread_data, thread_data.partial_x_samples)
thread_data.partial_x_samples = None
finally:
model_loader.gc(thread_data)

View File

@ -219,7 +219,7 @@ def thread_get_next_task():
def thread_render(device):
global current_state, current_state_error
from sd_internal import runtime2
from sd_internal import runtime2, model_manager
try:
runtime2.init(device)
except Exception as e:
@ -235,7 +235,7 @@ def thread_render(device):
'alive': True
}
runtime2.load_default_models()
model_manager.load_default_models(runtime2.thread_data)
current_state = ServerStates.Online
while True:
@ -243,7 +243,7 @@ def thread_render(device):
task_cache.clean()
if not weak_thread_data[threading.current_thread()]['alive']:
log.info(f'Shutting down thread for device {runtime2.thread_data.device}')
runtime2.destroy()
model_manager.unload_all(runtime2.thread_data)
return
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
@ -280,6 +280,7 @@ def thread_render(device):
runtime2.reload_models_if_necessary(task.task_data)
current_state = ServerStates.Rendering
runtime2.resolve_model_paths(task.task_data)
task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL)