mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-12-25 16:38:55 +01:00
Backend side merge API
This commit is contained in:
parent
b6f1194c93
commit
15a1436c8b
@ -13,7 +13,7 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from easydiffusion import app, model_manager, task_manager
|
||||
from easydiffusion.types import TaskData, GenerateImageRequest
|
||||
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
|
||||
from easydiffusion.utils import log
|
||||
|
||||
log.info(f'started in {app.SD_DIR}')
|
||||
@ -61,6 +61,11 @@ def init():
|
||||
def render(req: dict):
|
||||
return render_internal(req)
|
||||
|
||||
@server_api.post('/model/merge')
|
||||
def model_merge(req: dict):
|
||||
print(req)
|
||||
return model_merge_internal(req)
|
||||
|
||||
@server_api.get('/image/stream/{task_id:int}')
|
||||
def stream(task_id:int):
|
||||
return stream_internal(task_id)
|
||||
@ -181,6 +186,25 @@ def render_internal(req: dict):
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def model_merge_internal(req: dict):
|
||||
try:
|
||||
from sdkit.train.merge_models import merge_models
|
||||
from easydiffusion.utils.save_utils import filename_regex
|
||||
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
||||
print('model_merge_internal')
|
||||
print(mergeReq)
|
||||
|
||||
merge_models(model_manager.resolve_model_to_use(mergeReq.model0,'stable-diffusion'),
|
||||
model_manager.resolve_model_to_use(mergeReq.model1,'stable-diffusion'),
|
||||
mergeReq.ratio,
|
||||
os.path.join(app.MODELS_DIR, 'stable-diffusion', filename_regex.sub('_', mergeReq.out_path)),
|
||||
mergeReq.use_fp16
|
||||
)
|
||||
return JSONResponse({'status':'OK'}, headers=NOCACHE_HEADERS)
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def stream_internal(task_id:int):
|
||||
#TODO Move to WebSockets ??
|
||||
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||
|
@ -40,6 +40,13 @@ class TaskData(BaseModel):
|
||||
metadata_output_format: str = "txt" # or "json"
|
||||
stream_image_progress: bool = False
|
||||
|
||||
class MergeRequest(BaseModel):
|
||||
model0: str = None
|
||||
model1: str = None
|
||||
ratio: float = None
|
||||
out_path: str = "mix"
|
||||
use_fp16 = True
|
||||
|
||||
class Image:
|
||||
data: str # base64
|
||||
seed: int
|
||||
|
@ -7,7 +7,7 @@ from easydiffusion.types import TaskData, GenerateImageRequest
|
||||
|
||||
from sdkit.utils import save_images, save_dicts
|
||||
|
||||
filename_regex = re.compile('[^a-zA-Z0-9]')
|
||||
filename_regex = re.compile('[^a-zA-Z0-9._-]')
|
||||
|
||||
# keep in sync with `ui/media/js/dnd.js`
|
||||
TASK_TEXT_MAPPING = {
|
||||
@ -76,4 +76,4 @@ def make_filename_callback(req: GenerateImageRequest, suffix=None):
|
||||
name = name if suffix is None else f'{name}_{suffix}'
|
||||
return name
|
||||
|
||||
return make_filename
|
||||
return make_filename
|
||||
|
Loading…
Reference in New Issue
Block a user