forked from extern/easydiffusion
Simplify the runtime code
parent 0aa7968503
commit 97919c7e87
```diff
@@ -4,6 +4,8 @@ import picklescan.scanner
 import rich
 
 from sd_internal import app
+from modules import model_loader
+from modules.types import Context
 
 log = logging.getLogger()
 
```
```diff
@@ -30,6 +32,18 @@ def init():
     make_model_folders()
     getModels() # run this once, to cache the picklescan results
 
+def load_default_models(context: Context):
+    # init default model paths
+    for model_type in KNOWN_MODEL_TYPES:
+        context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
+
+    # load mandatory models
+    model_loader.load_model(context, 'stable-diffusion')
+
+def unload_all(context: Context):
+    for model_type in KNOWN_MODEL_TYPES:
+        model_loader.unload_model(context, model_type)
+
 def resolve_model_to_use(model_name:str=None, model_type:str=None):
     model_extensions = MODEL_EXTENSIONS.get(model_type, [])
     default_models = DEFAULT_MODELS.get(model_type, [])
```
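The two helpers added above move model lifecycle management behind an explicit `Context` object instead of module-level thread state. A minimal usage sketch, assuming a `Context` that can be constructed directly and carries a `model_paths` dict as the hunk implies (the setup shown here is illustrative, not part of this commit):

```python
# Hypothetical caller of the new context-based API.
from sd_internal import model_manager
from modules.types import Context

context = Context()  # assumed constructible; holds model_paths and loaded models

# Resolve a default path for every known model type, then load the
# mandatory 'stable-diffusion' model into this context.
model_manager.load_default_models(context)

# ... generate images with the models attached to `context` ...

# Symmetric teardown: unload every known model type from the context.
model_manager.unload_all(context)
```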
```diff
@@ -33,18 +33,6 @@ def init(device):
 
     device_manager.device_init(thread_data, device)
 
-def destroy():
-    for model_type in model_manager.KNOWN_MODEL_TYPES:
-        model_loader.unload_model(thread_data, model_type)
-
-def load_default_models():
-    # init default model paths
-    for model_type in model_manager.KNOWN_MODEL_TYPES:
-        thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type)
-
-    # load mandatory models
-    model_loader.load_model(thread_data, 'stable-diffusion')
-
 def reload_models_if_necessary(task_data: TaskData):
     model_paths_in_req = (
         ('hypernetwork', task_data.use_hypernetwork_model),
```
```diff
@@ -67,14 +55,6 @@ def reload_models_if_necessary(task_data: TaskData):
 
 def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
-        # resolve the model paths to use
-        resolve_model_paths(task_data)
-
-        # convert init image to PIL.Image
-        req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
-        req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None
-
-        # generate
         return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
     except Exception as e:
         log.error(traceback.format_exc())
```
```diff
@@ -86,17 +66,12 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
         raise e
 
 def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    metadata = req.dict()
-    del metadata['init_image']
-    del metadata['init_image_mask']
-    print(metadata)
-
     images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
     images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image)
 
     if task_data.save_to_disk_path is not None:
         out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
-        save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image)
+        save_images(images, out_path, metadata=req.to_metadata(), show_only_filtered_image=task_data.show_only_filtered_image)
 
     res = Response(req, task_data, images=construct_response(images))
     res = res.json()
```
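The inline metadata assembly removed above (`req.dict()` minus the two base64 image fields, plus a stray debug `print`) is collapsed into a single `req.to_metadata()` call. The method itself is not shown in this commit; a sketch consistent with the removed inline code, assuming `GenerateImageRequest` is a pydantic model:

```python
from typing import Optional
from pydantic import BaseModel

class GenerateImageRequest(BaseModel):
    # Only the fields relevant to this sketch; the real request has more.
    prompt: str = ''
    init_image: Optional[str] = None       # base64-encoded, potentially huge
    init_image_mask: Optional[str] = None  # base64-encoded, potentially huge

    def to_metadata(self) -> dict:
        # Mirrors the removed inline code: dump the request, then drop the
        # bulky image payloads so they don't end up in saved metadata.
        metadata = self.dict()
        del metadata['init_image']
        del metadata['init_image_mask']
        return metadata
```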
```diff
@@ -114,6 +89,7 @@ def resolve_model_paths(task_data: TaskData):
     if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan')
 
 def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+    log.info(req.to_metadata())
     thread_data.temp_images.clear()
 
     image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
```
```diff
@@ -125,9 +101,7 @@ def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
         images = []
         user_stopped = True
         if thread_data.partial_x_samples is not None:
-            for i in range(req.num_outputs):
-                images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
-
+            images = image_utils.latent_samples_to_images(thread_data, thread_data.partial_x_samples)
             thread_data.partial_x_samples = None
     finally:
         model_loader.gc(thread_data)
```
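The removed loop decoded each partial latent separately, unsqueezing sample `i` into a batch of one; it also assigned into an empty `images` list, which would have raised an `IndexError`, so the batched call fixes a latent bug as well. `latent_samples_to_images` is not shown in this commit; a plausible sketch of such a helper, reusing the `latent_to_img` decoder from the removed code (names follow the diff, not a confirmed implementation):

```python
def latent_samples_to_images(thread_data, latent_samples):
    # Decode every latent sample into an image. Each sample is [C, H, W];
    # unsqueeze(0) gives latent_to_img the [1, C, H, W] batch shape it
    # expects, exactly as the removed per-output loop did.
    return [
        image_utils.latent_to_img(thread_data, sample.unsqueeze(0))
        for sample in latent_samples
    ]
```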
```diff
@@ -219,7 +219,7 @@ def thread_get_next_task():
 def thread_render(device):
     global current_state, current_state_error
 
-    from sd_internal import runtime2
+    from sd_internal import runtime2, model_manager
     try:
         runtime2.init(device)
     except Exception as e:
```
```diff
@@ -235,7 +235,7 @@ def thread_render(device):
         'alive': True
     }
 
-    runtime2.load_default_models()
+    model_manager.load_default_models(runtime2.thread_data)
     current_state = ServerStates.Online
 
     while True:
```
```diff
@@ -243,7 +243,7 @@ def thread_render(device):
         task_cache.clean()
         if not weak_thread_data[threading.current_thread()]['alive']:
             log.info(f'Shutting down thread for device {runtime2.thread_data.device}')
-            runtime2.destroy()
+            model_manager.unload_all(runtime2.thread_data)
             return
         if isinstance(current_state_error, SystemExit):
             current_state = ServerStates.Unavailable
```
```diff
@@ -280,6 +280,7 @@ def thread_render(device):
             runtime2.reload_models_if_necessary(task.task_data)
 
             current_state = ServerStates.Rendering
+            runtime2.resolve_model_paths(task.task_data)
             task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
             # Before looping back to the generator, mark cache as still alive.
             task_cache.keep(id(task), TASK_TTL)
```
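Net effect on the render loop: model-path resolution is no longer buried inside `make_images()` (see the hunk at `-67,14`) but happens explicitly in the render thread, between reloading models and generating. A condensed sketch of the resulting task flow, paraphrased from the hunks above rather than quoted verbatim:

```python
# Inside thread_render's task loop, after this commit:
runtime2.reload_models_if_necessary(task.task_data)  # swap models if the request changed them

current_state = ServerStates.Rendering
runtime2.resolve_model_paths(task.task_data)         # explicit now, not hidden in make_images()
task.response = runtime2.make_images(
    task.render_request, task.task_data,
    task.buffer_queue, task.temp_images, step_callback,
)
```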