mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-05-20 08:10:46 +02:00
commit
fe89d487f6
@ -13,7 +13,7 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from easydiffusion import app, model_manager, task_manager
|
from easydiffusion import app, model_manager, task_manager
|
||||||
from easydiffusion.types import TaskData, GenerateImageRequest
|
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
log.info(f'started in {app.SD_DIR}')
|
log.info(f'started in {app.SD_DIR}')
|
||||||
@ -61,6 +61,11 @@ def init():
|
|||||||
def render(req: dict):
|
def render(req: dict):
|
||||||
return render_internal(req)
|
return render_internal(req)
|
||||||
|
|
||||||
|
@server_api.post('/model/merge')
|
||||||
|
def model_merge(req: dict):
|
||||||
|
print(req)
|
||||||
|
return model_merge_internal(req)
|
||||||
|
|
||||||
@server_api.get('/image/stream/{task_id:int}')
|
@server_api.get('/image/stream/{task_id:int}')
|
||||||
def stream(task_id:int):
|
def stream(task_id:int):
|
||||||
return stream_internal(task_id)
|
return stream_internal(task_id)
|
||||||
@ -181,6 +186,23 @@ def render_internal(req: dict):
|
|||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
def model_merge_internal(req: dict):
|
||||||
|
try:
|
||||||
|
from sdkit.train import merge_models
|
||||||
|
from easydiffusion.utils.save_utils import filename_regex
|
||||||
|
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
||||||
|
|
||||||
|
merge_models(model_manager.resolve_model_to_use(mergeReq.model0,'stable-diffusion'),
|
||||||
|
model_manager.resolve_model_to_use(mergeReq.model1,'stable-diffusion'),
|
||||||
|
mergeReq.ratio,
|
||||||
|
os.path.join(app.MODELS_DIR, 'stable-diffusion', filename_regex.sub('_', mergeReq.out_path)),
|
||||||
|
mergeReq.use_fp16
|
||||||
|
)
|
||||||
|
return JSONResponse({'status':'OK'}, headers=NOCACHE_HEADERS)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(traceback.format_exc())
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
def stream_internal(task_id:int):
|
def stream_internal(task_id:int):
|
||||||
#TODO Move to WebSockets ??
|
#TODO Move to WebSockets ??
|
||||||
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||||
|
@ -41,6 +41,13 @@ class TaskData(BaseModel):
|
|||||||
metadata_output_format: str = "txt" # or "json"
|
metadata_output_format: str = "txt" # or "json"
|
||||||
stream_image_progress: bool = False
|
stream_image_progress: bool = False
|
||||||
|
|
||||||
|
class MergeRequest(BaseModel):
|
||||||
|
model0: str = None
|
||||||
|
model1: str = None
|
||||||
|
ratio: float = None
|
||||||
|
out_path: str = "mix"
|
||||||
|
use_fp16 = True
|
||||||
|
|
||||||
class Image:
|
class Image:
|
||||||
data: str # base64
|
data: str # base64
|
||||||
seed: int
|
seed: int
|
||||||
|
@ -7,7 +7,7 @@ from easydiffusion.types import TaskData, GenerateImageRequest
|
|||||||
|
|
||||||
from sdkit.utils import save_images, save_dicts
|
from sdkit.utils import save_images, save_dicts
|
||||||
|
|
||||||
filename_regex = re.compile('[^a-zA-Z0-9]')
|
filename_regex = re.compile('[^a-zA-Z0-9._-]')
|
||||||
|
|
||||||
# keep in sync with `ui/media/js/dnd.js`
|
# keep in sync with `ui/media/js/dnd.js`
|
||||||
TASK_TEXT_MAPPING = {
|
TASK_TEXT_MAPPING = {
|
||||||
|
Loading…
Reference in New Issue
Block a user