Update to use the latest sdkit API

This commit is contained in:
cmdr2 2022-12-20 15:16:47 +05:30
parent 47e3884994
commit 5eeef41d8c
6 changed files with 50 additions and 42 deletions

View File

@ -5,14 +5,14 @@ from easydiffusion import app, device_manager
from easydiffusion.types import TaskData
from easydiffusion.utils import log
from sdkit.models import model_loader
from sdkit.types import Context
from sdkit import Context
from sdkit.models import load_model, unload_model
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
MODEL_EXTENSIONS = {
'stable-diffusion': ['.ckpt', '.safetensors'],
'vae': ['.vae.pt', '.ckpt'],
'hypernetwork': ['.pt'],
'vae': ['.vae.pt', '.ckpt', '.safetensors'],
'hypernetwork': ['.pt', '.safetensors'],
'gfpgan': ['.pth'],
'realesrgan': ['.pth'],
}
@ -44,13 +44,13 @@ def load_default_models(context: Context):
set_vram_optimizations(context)
# load mandatory models
model_loader.load_model(context, 'stable-diffusion')
model_loader.load_model(context, 'vae')
model_loader.load_model(context, 'hypernetwork')
load_model(context, 'stable-diffusion')
load_model(context, 'vae')
load_model(context, 'hypernetwork')
def unload_all(context: Context):
for model_type in KNOWN_MODEL_TYPES:
model_loader.unload_model(context, model_type)
unload_model(context, model_type)
def resolve_model_to_use(model_name:str=None, model_type:str=None):
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
@ -107,7 +107,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
for model_type, model_path_in_req in models_to_reload.items():
context.model_paths[model_type] = model_path_in_req
action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model
action_fn = unload_model if context.model_paths[model_type] is None else load_model
action_fn(context, model_type)
def resolve_model_paths(task_data: TaskData):

View File

@ -3,12 +3,13 @@ import time
import json
from easydiffusion import device_manager
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest
from easydiffusion.utils import get_printable_request, save_images_to_disk, log
from sdkit import model_loader, image_generator, filters as image_filters
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images
from sdkit.types import Context, GenerateImageRequest, FilterImageRequest
from sdkit import Context
from sdkit.generate import generate_images
from sdkit.filter import apply_filters
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
context = Context() # thread-local
'''
@ -30,7 +31,7 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
log.info(f'request: {get_printable_request(req)}')
log.info(f'task data: {task_data.dict()}')
images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
images = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed))
res = res.json()
@ -39,22 +40,22 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
return res
def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
images, user_stopped = generate_images(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
filtered_images = apply_filters(task_data, images, user_stopped)
def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
filtered_images = filter_images(task_data, images, user_stopped)
if task_data.save_to_disk_path is not None:
save_images_to_disk(images, filtered_images, req, task_data)
return filtered_images if task_data.show_only_filtered_image else images + filtered_images
def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
context.temp_images.clear()
image_generator.on_image_step = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
try:
images = image_generator.make_images(context=context, req=req)
images = generate_images(context, callback=callback, **req.dict())
user_stopped = False
except UserInitiatedStop:
images = []
@ -63,27 +64,19 @@ def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue:
images = latent_samples_to_images(context, context.partial_x_samples)
context.partial_x_samples = None
finally:
model_loader.gc(context)
gc(context)
return images, user_stopped
def apply_filters(task_data: TaskData, images: list, user_stopped):
def filter_images(task_data: TaskData, images: list, user_stopped):
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
return images
filters = []
if 'gfpgan' in task_data.use_face_correction.lower(): filters.append('gfpgan')
if 'realesrgan' in task_data.use_face_correction.lower(): filters.append('realesrgan')
filters_to_apply = []
if 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
if 'realesrgan' in task_data.use_face_correction.lower(): filters_to_apply.append('realesrgan')
filtered_images = []
for img in images:
filter_req = FilterImageRequest()
filter_req.init_image = img
filtered_image = image_filters.apply(context, filters, filter_req)
filtered_images.append(filtered_image)
return filtered_images
return apply_filters(context, filters_to_apply, images)
def construct_response(images: list, task_data: TaskData, base_seed: int):
return [

View File

@ -14,11 +14,9 @@ import queue, threading, time, weakref
from typing import Any, Hashable
from easydiffusion import device_manager
from easydiffusion.types import TaskData
from easydiffusion.types import TaskData, GenerateImageRequest
from easydiffusion.utils import log
from sdkit.types import GenerateImageRequest
THREAD_NAME_PREFIX = ''
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.

View File

@ -1,6 +1,25 @@
from pydantic import BaseModel
from typing import Any
from sdkit.types import GenerateImageRequest
class GenerateImageRequest(BaseModel):
prompt: str = ""
negative_prompt: str = ""
seed: int = 42
width: int = 512
height: int = 512
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
init_image: Any = None
init_image_mask: Any = None
prompt_strength: float = 0.8
preserve_init_image_color_profile = False
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
hypernetwork_strength: float = 0
class TaskData(BaseModel):
request_id: str = None

View File

@ -3,10 +3,9 @@ import time
import base64
import re
from easydiffusion.types import TaskData
from easydiffusion.types import TaskData, GenerateImageRequest
from sdkit.utils import save_images, save_dicts
from sdkit.types import GenerateImageRequest
filename_regex = re.compile('[^a-zA-Z0-9]')

View File

@ -13,9 +13,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import TaskData
from easydiffusion.types import TaskData, GenerateImageRequest
from easydiffusion.utils import log
from sdkit.types import GenerateImageRequest
log.info(f'started in {app.SD_DIR}')
log.info(f'started at {datetime.datetime.now():%x %X}')