From 5eeef41d8cabcdc485e8a1095198c07f7ac92632 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 20 Dec 2022 15:16:47 +0530 Subject: [PATCH] Update to use the latest sdkit API --- ui/easydiffusion/model_manager.py | 18 ++++++------ ui/easydiffusion/renderer.py | 43 ++++++++++++---------------- ui/easydiffusion/task_manager.py | 4 +-- ui/easydiffusion/types.py | 21 +++++++++++++- ui/easydiffusion/utils/save_utils.py | 3 +- ui/server.py | 3 +- 6 files changed, 50 insertions(+), 42 deletions(-) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index eb8aa7fd..4c9fb1d1 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -5,14 +5,14 @@ from easydiffusion import app, device_manager from easydiffusion.types import TaskData from easydiffusion.utils import log -from sdkit.models import model_loader -from sdkit.types import Context +from sdkit import Context +from sdkit.models import load_model, unload_model KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] MODEL_EXTENSIONS = { 'stable-diffusion': ['.ckpt', '.safetensors'], - 'vae': ['.vae.pt', '.ckpt'], - 'hypernetwork': ['.pt'], + 'vae': ['.vae.pt', '.ckpt', '.safetensors'], + 'hypernetwork': ['.pt', '.safetensors'], 'gfpgan': ['.pth'], 'realesrgan': ['.pth'], } @@ -44,13 +44,13 @@ def load_default_models(context: Context): set_vram_optimizations(context) # load mandatory models - model_loader.load_model(context, 'stable-diffusion') - model_loader.load_model(context, 'vae') - model_loader.load_model(context, 'hypernetwork') + load_model(context, 'stable-diffusion') + load_model(context, 'vae') + load_model(context, 'hypernetwork') def unload_all(context: Context): for model_type in KNOWN_MODEL_TYPES: - model_loader.unload_model(context, model_type) + unload_model(context, model_type) def resolve_model_to_use(model_name:str=None, model_type:str=None): model_extensions = MODEL_EXTENSIONS.get(model_type, []) @@ -107,7 +107,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData): for model_type, model_path_in_req in models_to_reload.items(): context.model_paths[model_type] = model_path_in_req - action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model + action_fn = unload_model if context.model_paths[model_type] is None else load_model action_fn(context, model_type) def resolve_model_paths(task_data: TaskData): diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py index 610e14fa..4b9847c0 100644 --- a/ui/easydiffusion/renderer.py +++ b/ui/easydiffusion/renderer.py @@ -3,12 +3,13 @@ import time import json from easydiffusion import device_manager -from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop +from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest from easydiffusion.utils import get_printable_request, save_images_to_disk, log -from sdkit import model_loader, image_generator, filters as image_filters -from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images -from sdkit.types import Context, GenerateImageRequest, FilterImageRequest +from sdkit import Context +from sdkit.generate import generate_images +from sdkit.filter import apply_filters +from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc context = Context() # thread-local ''' @@ -30,7 +31,7 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu log.info(f'request: {get_printable_request(req)}') log.info(f'task data: {task_data.dict()}') - images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) + images = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed)) res = res.json() @@ -39,22 +40,22 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu return res -def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): - images, user_stopped = generate_images(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) - filtered_images = apply_filters(task_data, images, user_stopped) +def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): + images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) + filtered_images = filter_images(task_data, images, user_stopped) if task_data.save_to_disk_path is not None: save_images_to_disk(images, filtered_images, req, task_data) return filtered_images if task_data.show_only_filtered_image else images + filtered_images -def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): +def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): context.temp_images.clear() - image_generator.on_image_step = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress) + callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress) try: - images = image_generator.make_images(context=context, req=req) + images = generate_images(context, callback=callback, **req.dict()) user_stopped = False except UserInitiatedStop: images = [] @@ -63,27 +64,19 @@ def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: images = latent_samples_to_images(context, context.partial_x_samples) context.partial_x_samples = None finally: - model_loader.gc(context) + gc(context) return images, user_stopped -def apply_filters(task_data: TaskData, images: list, user_stopped): +def filter_images(task_data: TaskData, images: list, user_stopped): if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None): return images - filters = [] - if 'gfpgan' in task_data.use_face_correction.lower(): filters.append('gfpgan') - if 'realesrgan' in task_data.use_face_correction.lower(): filters.append('realesrgan') + filters_to_apply = [] + if 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan') + if 'realesrgan' in task_data.use_face_correction.lower(): filters_to_apply.append('realesrgan') - filtered_images = [] - for img in images: - filter_req = FilterImageRequest() - filter_req.init_image = img - - filtered_image = image_filters.apply(context, filters, filter_req) - filtered_images.append(filtered_image) - - return filtered_images + return apply_filters(context, filters_to_apply, images) def construct_response(images: list, task_data: TaskData, base_seed: int): return [ diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py index 19638715..3a764137 100644 --- a/ui/easydiffusion/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -14,11 +14,9 @@ import queue, threading, time, weakref from typing import Any, Hashable from easydiffusion import device_manager -from easydiffusion.types import TaskData +from easydiffusion.types import TaskData, GenerateImageRequest from easydiffusion.utils import log -from sdkit.types import GenerateImageRequest - THREAD_NAME_PREFIX = '' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 3a748431..2a10d521 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -1,6 +1,25 @@ from pydantic import BaseModel +from typing import Any -from sdkit.types import GenerateImageRequest +class GenerateImageRequest(BaseModel): + prompt: str = "" + negative_prompt: str = "" + + seed: int = 42 + width: int = 512 + height: int = 512 + + num_outputs: int = 1 + num_inference_steps: int = 50 + guidance_scale: float = 7.5 + + init_image: Any = None + init_image_mask: Any = None + prompt_strength: float = 0.8 + preserve_init_image_color_profile = False + + sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" + hypernetwork_strength: float = 0 class TaskData(BaseModel): request_id: str = None diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index bb1d09c9..d7fa82a3 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -3,10 +3,9 @@ import time import base64 import re -from easydiffusion.types import TaskData +from easydiffusion.types import TaskData, GenerateImageRequest from sdkit.utils import save_images, save_dicts -from sdkit.types import GenerateImageRequest filename_regex = re.compile('[^a-zA-Z0-9]') diff --git a/ui/server.py b/ui/server.py index 8ee34fec..8d9a129e 100644 --- a/ui/server.py +++ b/ui/server.py @@ -13,9 +13,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse from pydantic import BaseModel from easydiffusion import app, model_manager, task_manager -from easydiffusion.types import TaskData +from easydiffusion.types import TaskData, GenerateImageRequest from easydiffusion.utils import log -from sdkit.types import GenerateImageRequest log.info(f'started in {app.SD_DIR}') log.info(f'started at {datetime.datetime.now():%x %X}')