forked from extern/easydiffusion
Update to use the latest sdkit API
This commit is contained in:
parent
47e3884994
commit
5eeef41d8c
@ -5,14 +5,14 @@ from easydiffusion import app, device_manager
|
|||||||
from easydiffusion.types import TaskData
|
from easydiffusion.types import TaskData
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
from sdkit.models import model_loader
|
from sdkit import Context
|
||||||
from sdkit.types import Context
|
from sdkit.models import load_model, unload_model
|
||||||
|
|
||||||
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
||||||
MODEL_EXTENSIONS = {
|
MODEL_EXTENSIONS = {
|
||||||
'stable-diffusion': ['.ckpt', '.safetensors'],
|
'stable-diffusion': ['.ckpt', '.safetensors'],
|
||||||
'vae': ['.vae.pt', '.ckpt'],
|
'vae': ['.vae.pt', '.ckpt', '.safetensors'],
|
||||||
'hypernetwork': ['.pt'],
|
'hypernetwork': ['.pt', '.safetensors'],
|
||||||
'gfpgan': ['.pth'],
|
'gfpgan': ['.pth'],
|
||||||
'realesrgan': ['.pth'],
|
'realesrgan': ['.pth'],
|
||||||
}
|
}
|
||||||
@ -44,13 +44,13 @@ def load_default_models(context: Context):
|
|||||||
set_vram_optimizations(context)
|
set_vram_optimizations(context)
|
||||||
|
|
||||||
# load mandatory models
|
# load mandatory models
|
||||||
model_loader.load_model(context, 'stable-diffusion')
|
load_model(context, 'stable-diffusion')
|
||||||
model_loader.load_model(context, 'vae')
|
load_model(context, 'vae')
|
||||||
model_loader.load_model(context, 'hypernetwork')
|
load_model(context, 'hypernetwork')
|
||||||
|
|
||||||
def unload_all(context: Context):
|
def unload_all(context: Context):
|
||||||
for model_type in KNOWN_MODEL_TYPES:
|
for model_type in KNOWN_MODEL_TYPES:
|
||||||
model_loader.unload_model(context, model_type)
|
unload_model(context, model_type)
|
||||||
|
|
||||||
def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
||||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
@ -107,7 +107,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
|
|||||||
for model_type, model_path_in_req in models_to_reload.items():
|
for model_type, model_path_in_req in models_to_reload.items():
|
||||||
context.model_paths[model_type] = model_path_in_req
|
context.model_paths[model_type] = model_path_in_req
|
||||||
|
|
||||||
action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model
|
action_fn = unload_model if context.model_paths[model_type] is None else load_model
|
||||||
action_fn(context, model_type)
|
action_fn(context, model_type)
|
||||||
|
|
||||||
def resolve_model_paths(task_data: TaskData):
|
def resolve_model_paths(task_data: TaskData):
|
||||||
|
@ -3,12 +3,13 @@ import time
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from easydiffusion import device_manager
|
from easydiffusion import device_manager
|
||||||
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop
|
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest
|
||||||
from easydiffusion.utils import get_printable_request, save_images_to_disk, log
|
from easydiffusion.utils import get_printable_request, save_images_to_disk, log
|
||||||
|
|
||||||
from sdkit import model_loader, image_generator, filters as image_filters
|
from sdkit import Context
|
||||||
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images
|
from sdkit.generate import generate_images
|
||||||
from sdkit.types import Context, GenerateImageRequest, FilterImageRequest
|
from sdkit.filter import apply_filters
|
||||||
|
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
|
||||||
|
|
||||||
context = Context() # thread-local
|
context = Context() # thread-local
|
||||||
'''
|
'''
|
||||||
@ -30,7 +31,7 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
|
|||||||
log.info(f'request: {get_printable_request(req)}')
|
log.info(f'request: {get_printable_request(req)}')
|
||||||
log.info(f'task data: {task_data.dict()}')
|
log.info(f'task data: {task_data.dict()}')
|
||||||
|
|
||||||
images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
images = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
||||||
|
|
||||||
res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed))
|
res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed))
|
||||||
res = res.json()
|
res = res.json()
|
||||||
@ -39,22 +40,22 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||||
images, user_stopped = generate_images(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
|
images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
|
||||||
filtered_images = apply_filters(task_data, images, user_stopped)
|
filtered_images = filter_images(task_data, images, user_stopped)
|
||||||
|
|
||||||
if task_data.save_to_disk_path is not None:
|
if task_data.save_to_disk_path is not None:
|
||||||
save_images_to_disk(images, filtered_images, req, task_data)
|
save_images_to_disk(images, filtered_images, req, task_data)
|
||||||
|
|
||||||
return filtered_images if task_data.show_only_filtered_image else images + filtered_images
|
return filtered_images if task_data.show_only_filtered_image else images + filtered_images
|
||||||
|
|
||||||
def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||||
context.temp_images.clear()
|
context.temp_images.clear()
|
||||||
|
|
||||||
image_generator.on_image_step = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
|
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
images = image_generator.make_images(context=context, req=req)
|
images = generate_images(context, callback=callback, **req.dict())
|
||||||
user_stopped = False
|
user_stopped = False
|
||||||
except UserInitiatedStop:
|
except UserInitiatedStop:
|
||||||
images = []
|
images = []
|
||||||
@ -63,27 +64,19 @@ def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue:
|
|||||||
images = latent_samples_to_images(context, context.partial_x_samples)
|
images = latent_samples_to_images(context, context.partial_x_samples)
|
||||||
context.partial_x_samples = None
|
context.partial_x_samples = None
|
||||||
finally:
|
finally:
|
||||||
model_loader.gc(context)
|
gc(context)
|
||||||
|
|
||||||
return images, user_stopped
|
return images, user_stopped
|
||||||
|
|
||||||
def apply_filters(task_data: TaskData, images: list, user_stopped):
|
def filter_images(task_data: TaskData, images: list, user_stopped):
|
||||||
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
|
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
|
||||||
return images
|
return images
|
||||||
|
|
||||||
filters = []
|
filters_to_apply = []
|
||||||
if 'gfpgan' in task_data.use_face_correction.lower(): filters.append('gfpgan')
|
if 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
|
||||||
if 'realesrgan' in task_data.use_face_correction.lower(): filters.append('realesrgan')
|
if 'realesrgan' in task_data.use_face_correction.lower(): filters_to_apply.append('realesrgan')
|
||||||
|
|
||||||
filtered_images = []
|
return apply_filters(context, filters_to_apply, images)
|
||||||
for img in images:
|
|
||||||
filter_req = FilterImageRequest()
|
|
||||||
filter_req.init_image = img
|
|
||||||
|
|
||||||
filtered_image = image_filters.apply(context, filters, filter_req)
|
|
||||||
filtered_images.append(filtered_image)
|
|
||||||
|
|
||||||
return filtered_images
|
|
||||||
|
|
||||||
def construct_response(images: list, task_data: TaskData, base_seed: int):
|
def construct_response(images: list, task_data: TaskData, base_seed: int):
|
||||||
return [
|
return [
|
||||||
|
@ -14,11 +14,9 @@ import queue, threading, time, weakref
|
|||||||
from typing import Any, Hashable
|
from typing import Any, Hashable
|
||||||
|
|
||||||
from easydiffusion import device_manager
|
from easydiffusion import device_manager
|
||||||
from easydiffusion.types import TaskData
|
from easydiffusion.types import TaskData, GenerateImageRequest
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
from sdkit.types import GenerateImageRequest
|
|
||||||
|
|
||||||
THREAD_NAME_PREFIX = ''
|
THREAD_NAME_PREFIX = ''
|
||||||
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
||||||
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
||||||
|
@ -1,6 +1,25 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from sdkit.types import GenerateImageRequest
|
class GenerateImageRequest(BaseModel):
|
||||||
|
prompt: str = ""
|
||||||
|
negative_prompt: str = ""
|
||||||
|
|
||||||
|
seed: int = 42
|
||||||
|
width: int = 512
|
||||||
|
height: int = 512
|
||||||
|
|
||||||
|
num_outputs: int = 1
|
||||||
|
num_inference_steps: int = 50
|
||||||
|
guidance_scale: float = 7.5
|
||||||
|
|
||||||
|
init_image: Any = None
|
||||||
|
init_image_mask: Any = None
|
||||||
|
prompt_strength: float = 0.8
|
||||||
|
preserve_init_image_color_profile = False
|
||||||
|
|
||||||
|
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
|
||||||
|
hypernetwork_strength: float = 0
|
||||||
|
|
||||||
class TaskData(BaseModel):
|
class TaskData(BaseModel):
|
||||||
request_id: str = None
|
request_id: str = None
|
||||||
|
@ -3,10 +3,9 @@ import time
|
|||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from easydiffusion.types import TaskData
|
from easydiffusion.types import TaskData, GenerateImageRequest
|
||||||
|
|
||||||
from sdkit.utils import save_images, save_dicts
|
from sdkit.utils import save_images, save_dicts
|
||||||
from sdkit.types import GenerateImageRequest
|
|
||||||
|
|
||||||
filename_regex = re.compile('[^a-zA-Z0-9]')
|
filename_regex = re.compile('[^a-zA-Z0-9]')
|
||||||
|
|
||||||
|
@ -13,9 +13,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from easydiffusion import app, model_manager, task_manager
|
from easydiffusion import app, model_manager, task_manager
|
||||||
from easydiffusion.types import TaskData
|
from easydiffusion.types import TaskData, GenerateImageRequest
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
from sdkit.types import GenerateImageRequest
|
|
||||||
|
|
||||||
log.info(f'started in {app.SD_DIR}')
|
log.info(f'started in {app.SD_DIR}')
|
||||||
log.info(f'started at {datetime.datetime.now():%x %X}')
|
log.info(f'started at {datetime.datetime.now():%x %X}')
|
||||||
|
Loading…
Reference in New Issue
Block a user