forked from extern/easydiffusion
Simplify the runtime code
This commit is contained in:
parent
0aa7968503
commit
97919c7e87
@ -4,6 +4,8 @@ import picklescan.scanner
|
||||
import rich
|
||||
|
||||
from sd_internal import app
|
||||
from modules import model_loader
|
||||
from modules.types import Context
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
@ -30,6 +32,18 @@ def init():
|
||||
make_model_folders()
|
||||
getModels() # run this once, to cache the picklescan results
|
||||
|
||||
def load_default_models(context: Context):
|
||||
# init default model paths
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
|
||||
|
||||
# load mandatory models
|
||||
model_loader.load_model(context, 'stable-diffusion')
|
||||
|
||||
def unload_all(context: Context):
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
model_loader.unload_model(context, model_type)
|
||||
|
||||
def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||
|
@ -33,18 +33,6 @@ def init(device):
|
||||
|
||||
device_manager.device_init(thread_data, device)
|
||||
|
||||
def destroy():
|
||||
for model_type in model_manager.KNOWN_MODEL_TYPES:
|
||||
model_loader.unload_model(thread_data, model_type)
|
||||
|
||||
def load_default_models():
|
||||
# init default model paths
|
||||
for model_type in model_manager.KNOWN_MODEL_TYPES:
|
||||
thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type)
|
||||
|
||||
# load mandatory models
|
||||
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||
|
||||
def reload_models_if_necessary(task_data: TaskData):
|
||||
model_paths_in_req = (
|
||||
('hypernetwork', task_data.use_hypernetwork_model),
|
||||
@ -67,14 +55,6 @@ def reload_models_if_necessary(task_data: TaskData):
|
||||
|
||||
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
try:
|
||||
# resolve the model paths to use
|
||||
resolve_model_paths(task_data)
|
||||
|
||||
# convert init image to PIL.Image
|
||||
req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
|
||||
req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None
|
||||
|
||||
# generate
|
||||
return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
@ -86,17 +66,12 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
|
||||
raise e
|
||||
|
||||
def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
metadata = req.dict()
|
||||
del metadata['init_image']
|
||||
del metadata['init_image_mask']
|
||||
print(metadata)
|
||||
|
||||
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
|
||||
images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image)
|
||||
|
||||
if task_data.save_to_disk_path is not None:
|
||||
out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
|
||||
save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image)
|
||||
save_images(images, out_path, metadata=req.to_metadata(), show_only_filtered_image=task_data.show_only_filtered_image)
|
||||
|
||||
res = Response(req, task_data, images=construct_response(images))
|
||||
res = res.json()
|
||||
@ -114,6 +89,7 @@ def resolve_model_paths(task_data: TaskData):
|
||||
if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan')
|
||||
|
||||
def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||
log.info(req.to_metadata())
|
||||
thread_data.temp_images.clear()
|
||||
|
||||
image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
|
||||
@ -125,13 +101,11 @@ def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_tem
|
||||
images = []
|
||||
user_stopped = True
|
||||
if thread_data.partial_x_samples is not None:
|
||||
for i in range(req.num_outputs):
|
||||
images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
|
||||
|
||||
thread_data.partial_x_samples = None
|
||||
images = image_utils.latent_samples_to_images(thread_data, thread_data.partial_x_samples)
|
||||
thread_data.partial_x_samples = None
|
||||
finally:
|
||||
model_loader.gc(thread_data)
|
||||
|
||||
|
||||
images = [(image, req.seed + i, False) for i, image in enumerate(images)]
|
||||
|
||||
return images, user_stopped
|
||||
|
@ -219,7 +219,7 @@ def thread_get_next_task():
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error
|
||||
|
||||
from sd_internal import runtime2
|
||||
from sd_internal import runtime2, model_manager
|
||||
try:
|
||||
runtime2.init(device)
|
||||
except Exception as e:
|
||||
@ -235,7 +235,7 @@ def thread_render(device):
|
||||
'alive': True
|
||||
}
|
||||
|
||||
runtime2.load_default_models()
|
||||
model_manager.load_default_models(runtime2.thread_data)
|
||||
current_state = ServerStates.Online
|
||||
|
||||
while True:
|
||||
@ -243,7 +243,7 @@ def thread_render(device):
|
||||
task_cache.clean()
|
||||
if not weak_thread_data[threading.current_thread()]['alive']:
|
||||
log.info(f'Shutting down thread for device {runtime2.thread_data.device}')
|
||||
runtime2.destroy()
|
||||
model_manager.unload_all(runtime2.thread_data)
|
||||
return
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
current_state = ServerStates.Unavailable
|
||||
@ -280,6 +280,7 @@ def thread_render(device):
|
||||
runtime2.reload_models_if_necessary(task.task_data)
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
runtime2.resolve_model_paths(task.task_data)
|
||||
task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
|
Loading…
Reference in New Issue
Block a user