forked from extern/easydiffusion
Get rid of the ugly copying around (and maintaining) of multiple request-related fields. Split into two objects: task-related fields, and render-related fields. Also remove the ability for request-defined full-precision. Full-precision can now be forced by using a USE_FULL_PRECISION environment variable
This commit is contained in:
31
ui/server.py
31
ui/server.py
@ -14,6 +14,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sd_internal import app, model_manager, task_manager
|
||||
from sd_internal import TaskData
|
||||
from modules.types import GenerateImageRequest
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
@ -22,8 +24,6 @@ log.info(f'started at {datetime.datetime.now():%x %X}')
|
||||
|
||||
server_api = FastAPI()
|
||||
|
||||
# don't show access log entries for URLs that start with the given prefix
|
||||
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
|
||||
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||
|
||||
class NoCacheStaticFiles(StaticFiles):
|
||||
@ -43,17 +43,6 @@ class SetAppConfigRequest(BaseModel):
|
||||
listen_port: int = None
|
||||
test_sd2: bool = None
|
||||
|
||||
class LogSuppressFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
path = record.getMessage()
|
||||
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
|
||||
if path.find(prefix) != -1:
|
||||
return False
|
||||
return True
|
||||
|
||||
# don't log certain requests
|
||||
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
|
||||
|
||||
server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
|
||||
|
||||
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
|
||||
@ -137,20 +126,18 @@ def ping(session_id:str=None):
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
@server_api.post('/render')
|
||||
def render(req : task_manager.ImageRequest):
|
||||
def render(req: dict):
|
||||
try:
|
||||
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
|
||||
# separate out the request data into rendering and task-specific data
|
||||
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
|
||||
task_data: TaskData = TaskData.parse_obj(req)
|
||||
|
||||
# resolve the model paths to use
|
||||
req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
|
||||
req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
|
||||
req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
|
||||
render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
|
||||
|
||||
if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
|
||||
if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan')
|
||||
app.save_model_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model)
|
||||
|
||||
# enqueue the task
|
||||
new_task = task_manager.render(req)
|
||||
new_task = task_manager.render(render_req, task_data)
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
'queue': len(task_manager.tasks_queue),
|
||||
|
Reference in New Issue
Block a user