From 15a1436c8ba142627c90cc61a4af0c1537117dfa Mon Sep 17 00:00:00 2001 From: JeLuF Date: Fri, 30 Dec 2022 10:07:23 +0100 Subject: [PATCH] Backend side merge API --- ui/easydiffusion/server.py | 26 +++++++++++++++++++++++++- ui/easydiffusion/types.py | 7 +++++++ ui/easydiffusion/utils/save_utils.py | 4 ++-- 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index 56535f1f..f74a7303 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -13,7 +13,7 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse from pydantic import BaseModel from easydiffusion import app, model_manager, task_manager -from easydiffusion.types import TaskData, GenerateImageRequest +from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest from easydiffusion.utils import log log.info(f'started in {app.SD_DIR}') @@ -61,6 +61,11 @@ def init(): def render(req: dict): return render_internal(req) + @server_api.post('/model/merge') + def model_merge(req: dict): + print(req) + return model_merge_internal(req) + @server_api.get('/image/stream/{task_id:int}') def stream(task_id:int): return stream_internal(task_id) @@ -181,6 +186,25 @@ def render_internal(req: dict): log.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) +def model_merge_internal(req: dict): + try: + from sdkit.train.merge_models import merge_models + from easydiffusion.utils.save_utils import filename_regex + mergeReq: MergeRequest = MergeRequest.parse_obj(req) + print('model_merge_internal') + print(mergeReq) + + merge_models(model_manager.resolve_model_to_use(mergeReq.model0,'stable-diffusion'), + model_manager.resolve_model_to_use(mergeReq.model1,'stable-diffusion'), + mergeReq.ratio, + os.path.join(app.MODELS_DIR, 'stable-diffusion', filename_regex.sub('_', mergeReq.out_path)), + mergeReq.use_fp16 + ) + return JSONResponse({'status':'OK'}, headers=NOCACHE_HEADERS) + except Exception as e: + log.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + def stream_internal(task_id:int): #TODO Move to WebSockets ?? task = task_manager.get_cached_task(task_id, update_ttl=True) diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 805c8683..4362045b 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -40,6 +40,13 @@ class TaskData(BaseModel): metadata_output_format: str = "txt" # or "json" stream_image_progress: bool = False +class MergeRequest(BaseModel): + model0: str = None + model1: str = None + ratio: float = None + out_path: str = "mix" + use_fp16 = True + class Image: data: str # base64 seed: int diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index b9ce8aba..926e72a4 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -7,7 +7,7 @@ from easydiffusion.types import TaskData, GenerateImageRequest from sdkit.utils import save_images, save_dicts -filename_regex = re.compile('[^a-zA-Z0-9]') +filename_regex = re.compile('[^a-zA-Z0-9._-]') # keep in sync with `ui/media/js/dnd.js` TASK_TEXT_MAPPING = { @@ -76,4 +76,4 @@ def make_filename_callback(req: GenerateImageRequest, suffix=None): name = name if suffix is None else f'{name}_{suffix}' return name - return make_filename \ No newline at end of file + return make_filename