Get rid of the ugly copying around (and maintaining) of multiple request-related fields. Split into two objects: task-related fields, and render-related fields. Also remove the ability for request-defined full-precision. Full-precision can now be forced by using a USE_FULL_PRECISION environment variable

This commit is contained in:
cmdr2
2022-12-11 18:16:29 +05:30
parent d03eed3859
commit 6ce6dc3ff6
11 changed files with 115 additions and 305 deletions

View File

@ -14,6 +14,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
from sd_internal import app, model_manager, task_manager
from sd_internal import TaskData
from modules.types import GenerateImageRequest
log = logging.getLogger()
@ -22,8 +24,6 @@ log.info(f'started at {datetime.datetime.now():%x %X}')
server_api = FastAPI()
# don't show access log entries for URLs that start with the given prefix
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
class NoCacheStaticFiles(StaticFiles):
@ -43,17 +43,6 @@ class SetAppConfigRequest(BaseModel):
listen_port: int = None
test_sd2: bool = None
class LogSuppressFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
path = record.getMessage()
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
if path.find(prefix) != -1:
return False
return True
# don't log certain requests
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
@ -137,20 +126,18 @@ def ping(session_id:str=None):
return JSONResponse(response, headers=NOCACHE_HEADERS)
@server_api.post('/render')
def render(req : task_manager.ImageRequest):
def render(req: dict):
try:
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
# separate out the request data into rendering and task-specific data
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
task_data: TaskData = TaskData.parse_obj(req)
# resolve the model paths to use
req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan')
app.save_model_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model)
# enqueue the task
new_task = task_manager.render(req)
new_task = task_manager.render(render_req, task_data)
response = {
'status': str(task_manager.current_state),
'queue': len(task_manager.tasks_queue),