forked from extern/easydiffusion
Update to use the latest sdkit API
parent 47e3884994
commit 5eeef41d8c
@@ -5,14 +5,14 @@ from easydiffusion import app, device_manager
 from easydiffusion.types import TaskData
 from easydiffusion.utils import log
 
-from sdkit.models import model_loader
-from sdkit.types import Context
+from sdkit import Context
+from sdkit.models import load_model, unload_model
 
 KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
 MODEL_EXTENSIONS = {
     'stable-diffusion': ['.ckpt', '.safetensors'],
-    'vae': ['.vae.pt', '.ckpt'],
-    'hypernetwork': ['.pt'],
+    'vae': ['.vae.pt', '.ckpt', '.safetensors'],
+    'hypernetwork': ['.pt', '.safetensors'],
     'gfpgan': ['.pth'],
     'realesrgan': ['.pth'],
 }
@@ -44,13 +44,13 @@ def load_default_models(context: Context):
     set_vram_optimizations(context)
 
     # load mandatory models
-    model_loader.load_model(context, 'stable-diffusion')
-    model_loader.load_model(context, 'vae')
-    model_loader.load_model(context, 'hypernetwork')
+    load_model(context, 'stable-diffusion')
+    load_model(context, 'vae')
+    load_model(context, 'hypernetwork')
 
 def unload_all(context: Context):
     for model_type in KNOWN_MODEL_TYPES:
-        model_loader.unload_model(context, model_type)
+        unload_model(context, model_type)
 
 def resolve_model_to_use(model_name:str=None, model_type:str=None):
     model_extensions = MODEL_EXTENSIONS.get(model_type, [])
@@ -107,7 +107,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
     for model_type, model_path_in_req in models_to_reload.items():
         context.model_paths[model_type] = model_path_in_req
 
-        action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model
+        action_fn = unload_model if context.model_paths[model_type] is None else load_model
         action_fn(context, model_type)
 
 def resolve_model_paths(task_data: TaskData):
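For reference, a minimal sketch (not part of this commit) of the new sdkit model-management calls this file now uses. The checkpoint path is a made-up example; in easydiffusion it comes from resolve_model_to_use().

    from sdkit import Context
    from sdkit.models import load_model, unload_model

    context = Context()
    # tell sdkit where the checkpoint lives, then load it by model type
    context.model_paths['stable-diffusion'] = '/path/to/sd-v1-4.ckpt'  # hypothetical path
    load_model(context, 'stable-diffusion')

    # ... generate images ...

    unload_model(context, 'stable-diffusion')  # free the model when no longer needed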
@@ -3,12 +3,13 @@ import time
 import json
 
 from easydiffusion import device_manager
-from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop
+from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest
 from easydiffusion.utils import get_printable_request, save_images_to_disk, log
 
-from sdkit import model_loader, image_generator, filters as image_filters
-from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images
-from sdkit.types import Context, GenerateImageRequest, FilterImageRequest
+from sdkit import Context
+from sdkit.generate import generate_images
+from sdkit.filter import apply_filters
+from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
 
 context = Context() # thread-local
 '''
@@ -30,7 +31,7 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
     log.info(f'request: {get_printable_request(req)}')
     log.info(f'task data: {task_data.dict()}')
 
-    images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
+    images = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
 
     res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed))
     res = res.json()
@@ -39,22 +40,22 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
 
     return res
 
-def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    images, user_stopped = generate_images(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
-    filtered_images = apply_filters(task_data, images, user_stopped)
+def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
+    filtered_images = filter_images(task_data, images, user_stopped)
 
     if task_data.save_to_disk_path is not None:
         save_images_to_disk(images, filtered_images, req, task_data)
 
     return filtered_images if task_data.show_only_filtered_image else images + filtered_images
 
-def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
     context.temp_images.clear()
 
-    image_generator.on_image_step = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
+    callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
 
     try:
-        images = image_generator.make_images(context=context, req=req)
+        images = generate_images(context, callback=callback, **req.dict())
         user_stopped = False
     except UserInitiatedStop:
         images = []
@@ -63,27 +64,19 @@ def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue:
             images = latent_samples_to_images(context, context.partial_x_samples)
             context.partial_x_samples = None
     finally:
-        model_loader.gc(context)
+        gc(context)
 
     return images, user_stopped
 
-def apply_filters(task_data: TaskData, images: list, user_stopped):
+def filter_images(task_data: TaskData, images: list, user_stopped):
     if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
         return images
 
-    filters = []
-    if 'gfpgan' in task_data.use_face_correction.lower(): filters.append('gfpgan')
-    if 'realesrgan' in task_data.use_face_correction.lower(): filters.append('realesrgan')
+    filters_to_apply = []
+    if 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
+    if 'realesrgan' in task_data.use_face_correction.lower(): filters_to_apply.append('realesrgan')
 
-    filtered_images = []
-    for img in images:
-        filter_req = FilterImageRequest()
-        filter_req.init_image = img
-
-        filtered_image = image_filters.apply(context, filters, filter_req)
-        filtered_images.append(filtered_image)
-
-    return filtered_images
+    return apply_filters(context, filters_to_apply, images)
 
 def construct_response(images: list, task_data: TaskData, base_seed: int):
     return [
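For context, a minimal sketch (not part of this commit) of the new sdkit generation and filtering calls that replace image_generator.make_images() and image_filters.apply() above. The prompt and checkpoint path are illustrative assumptions.

    from sdkit import Context
    from sdkit.models import load_model
    from sdkit.generate import generate_images
    from sdkit.filter import apply_filters
    from sdkit.utils import gc

    context = Context()
    context.model_paths['stable-diffusion'] = '/path/to/model.ckpt'  # hypothetical path
    load_model(context, 'stable-diffusion')

    # keyword arguments mirror the GenerateImageRequest fields (prompt, seed, width, height, ...)
    images = generate_images(context, prompt='a photograph of an astronaut', seed=42, width=512, height=512)
    images = apply_filters(context, ['gfpgan'], images)  # face correction, as in filter_images() above
    gc(context)  # release memory held by intermediate tensors, as in the finally: block above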
@@ -14,11 +14,9 @@ import queue, threading, time, weakref
 from typing import Any, Hashable
 
 from easydiffusion import device_manager
-from easydiffusion.types import TaskData
+from easydiffusion.types import TaskData, GenerateImageRequest
 from easydiffusion.utils import log
 
-from sdkit.types import GenerateImageRequest
-
 THREAD_NAME_PREFIX = ''
 ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
 LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
@@ -1,6 +1,25 @@
 from pydantic import BaseModel
 from typing import Any
 
-from sdkit.types import GenerateImageRequest
+class GenerateImageRequest(BaseModel):
+    prompt: str = ""
+    negative_prompt: str = ""
+
+    seed: int = 42
+    width: int = 512
+    height: int = 512
+
+    num_outputs: int = 1
+    num_inference_steps: int = 50
+    guidance_scale: float = 7.5
+
+    init_image: Any = None
+    init_image_mask: Any = None
+    prompt_strength: float = 0.8
+    preserve_init_image_color_profile = False
+
+    sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
+    hypernetwork_strength: float = 0
+
 class TaskData(BaseModel):
     request_id: str = None
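For context, a minimal sketch (not part of this commit): with GenerateImageRequest now defined in easydiffusion.types as a pydantic model, a request can be built and expanded into the sdkit call the way renderer.py does with **req.dict(). The prompt below is a made-up example.

    from easydiffusion.types import GenerateImageRequest

    req = GenerateImageRequest(prompt='a watercolor landscape', num_inference_steps=25, guidance_scale=7.5)

    # renderer.py forwards the fields as keyword arguments to sdkit:
    #     images = generate_images(context, callback=callback, **req.dict())
    print(req.dict())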
@@ -3,10 +3,9 @@ import time
 import base64
 import re
 
-from easydiffusion.types import TaskData
+from easydiffusion.types import TaskData, GenerateImageRequest
 
 from sdkit.utils import save_images, save_dicts
-from sdkit.types import GenerateImageRequest
 
 filename_regex = re.compile('[^a-zA-Z0-9]')
 
@@ -13,9 +13,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
 from pydantic import BaseModel
 
 from easydiffusion import app, model_manager, task_manager
-from easydiffusion.types import TaskData
+from easydiffusion.types import TaskData, GenerateImageRequest
 from easydiffusion.utils import log
-from sdkit.types import GenerateImageRequest
 
 log.info(f'started in {app.SD_DIR}')
 log.info(f'started at {datetime.datetime.now():%x %X}')