easydiffusion/ui/server.py

516 lines
20 KiB
Python
Raw Normal View History

2022-09-09 17:04:32 +02:00
import json
import traceback
import sys
import os
# Directory layout. SD_DIR is the working directory the server was started from;
# SD_UI_DIR (from the SD_UI_PATH environment variable) holds index.html, media/ and
# modifiers.json. Its parent is put on sys.path so the UI-side packages can be imported.
SD_DIR = os.getcwd()
print('started in ', SD_DIR)
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
if SD_UI_DIR is None:
    # Fail with a clear message instead of the TypeError os.path.dirname(None) would raise.
    raise RuntimeError('The SD_UI_PATH environment variable must be set to the UI directory.')
sys.path.append(os.path.dirname(SD_UI_DIR))
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
# How long a session's task stays queryable in task_cache after it was last touched.
# TaskCache computes expiry from time.time(), which counts SECONDS, so this is
# 15 minutes expressed in seconds (the previous 15 * 60 * 1000 meant ~10 days).
TASK_TTL = 15 * 60
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
2022-10-14 09:47:25 +02:00
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
import logging
2022-10-14 09:47:25 +02:00
import queue, threading, time
from typing import Any, Generator, Hashable, Optional, Union
from sd_internal import Request, Response
app = FastAPI()

modifiers_cache = None
# Folder inside the user's home directory; reported by /get?key=output_dir.
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)

# Access-log lines are suppressed when their message contains any of these URL
# fragments (matched anywhere in the message, not only as a prefix).
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = [
    '/ping',
    '/image',
    '/modifier-thumbnails',
]

# Response headers that tell browsers and proxies never to cache the response.
NOCACHE_HEADERS = {
    "Cache-Control": "no-cache, no-store, must-revalidate",
    "Pragma": "no-cache",
    "Expires": "0",
}

media_dir = os.path.join(SD_UI_DIR, 'media/')
app.mount('/media', StaticFiles(directory=media_dir), name="media")
2022-10-14 09:47:25 +02:00
class SymbolClass(type):
    """Metaclass that makes Symbol classes print as readable names.

    repr() gives the qualified name (e.g. ``ServerStates.Online``) and str() gives
    the bare class name (e.g. ``Online``).
    """

    def __repr__(cls):
        return cls.__qualname__

    def __str__(cls):
        return cls.__name__


class Symbol(metaclass=SymbolClass):
    """Base for marker classes used as named constants; the class itself is the value."""


class ServerStates:
    """Server lifecycle states; /ping reports ``str(current_state)``."""

    class Init(Symbol):
        pass

    class LoadingModel(Symbol):
        pass

    class Online(Symbol):
        pass

    class Rendering(Symbol):
        pass

    class Unavailable(Symbol):
        pass
class RenderTask:
    """A queued render request together with everything produced while rendering it."""

    def __init__(self, req: Request):
        # One slot per requested output image, filled in as results arrive.
        output_slots = [None] * req.num_outputs
        self.request: Request = req  # the request as submitted
        self.response: Any = None  # copy of the last response
        self.temp_images: list = output_slots
        self.error: Exception = None  # failure or cancellation marker
        # Locked when the task starts and released when it completes.
        self.lock: threading.Lock = threading.Lock()
        # Queue of JSON string segments, drained by the streaming endpoint.
        self.buffer_queue: queue.Queue = queue.Queue()


# Shared state between the render thread and the HTTP handlers.
current_state = ServerStates.Init
current_state_error: Exception = None
current_model_path = None
tasks_queue = queue.Queue()
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
    """Request body accepted by POST /render.

    Fields whose default is None are declared ``Optional[...]`` explicitly: implicit
    Optional (``str = None``) is deprecated by PEP 484, and under pydantic v2 it would
    reject an explicit JSON ``null`` for these fields.
    """
    session_id: str = "session"
    prompt: str = ""
    negative_prompt: str = ""
    init_image: Optional[str] = None # base64
    mask: Optional[str] = None # base64
    num_outputs: int = 1
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    width: int = 512
    height: int = 512
    seed: int = 42
    prompt_strength: float = 0.8
    sampler: Optional[str] = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
    # allow_nsfw: bool = False
    save_to_disk_path: Optional[str] = None
    turbo: bool = True
    use_cpu: bool = False
    use_full_precision: bool = False
    use_face_correction: Optional[str] = None # or "GFPGANv1.3"
    use_upscale: Optional[str] = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
    use_stable_diffusion_model: str = "sd-v1-4"
    show_only_filtered_image: bool = False
    output_format: str = "jpeg" # or "png"
    stream_progress_updates: bool = False
    stream_image_progress: bool = False
# Temporary cache to allow to query tasks results for a short time after they are completed.
class TaskCache():
    """Key -> value store whose entries expire a given number of seconds after being set.

    Each entry is stored as ``(expiry_timestamp, value)``. clean() and keep() are
    called from the render thread while the route handlers call put()/tryGet(), so
    every operation is serialised by a reentrant lock.
    """

    def __init__(self):
        self._base = dict()
        # Reentrant so tryGet() can call delete() while already holding the lock.
        self._lock = threading.RLock()

    def _get_ttl_time(self, ttl: int) -> int:
        """Return the expiry timestamp for an entry set now with *ttl* seconds to live."""
        return int(time.time()) + ttl

    def _is_expired(self, timestamp: int) -> bool:
        """Return True once the current time has reached *timestamp*."""
        return int(time.time()) >= timestamp

    def clean(self) -> None:
        """Remove every expired entry."""
        with self._lock:
            # Collect first: deleting from a dict while iterating it raises RuntimeError.
            expired = [key for key, (ttl, _) in self._base.items() if self._is_expired(ttl)]
            for key in expired:
                del self._base[key]

    def clear(self) -> None:
        """Remove all entries."""
        with self._lock:
            self._base.clear()

    def delete(self, key: Hashable) -> bool:
        """Remove *key*; return False if it was not present."""
        with self._lock:
            if key not in self._base:
                return False
            del self._base[key]
            return True

    def keep(self, key: Hashable, ttl: int) -> bool:
        """Reset *key*'s expiry to *ttl* seconds from now; return False if absent."""
        with self._lock:
            if key in self._base:
                _, value = self._base.get(key)
                self._base[key] = (self._get_ttl_time(ttl), value)
                return True
            return False

    def put(self, key: Hashable, value: Any, ttl: int) -> bool:
        """Store *value* under *key* for *ttl* seconds; return False if the key is unusable."""
        with self._lock:
            try:
                self._base[key] = (self._get_ttl_time(ttl), value)
            except Exception:
                return False
            return True

    def tryGet(self, key: Hashable) -> Any:
        """Return the value for *key*, or None if missing or expired (expired ones are removed)."""
        with self._lock:
            ttl, value = self._base.get(key, (None, None))
            if ttl is not None and self._is_expired(ttl):
                self.delete(key)
                return None
            return value


task_cache = TaskCache()
class SetAppConfigRequest(BaseModel):
    """Request body accepted by POST /app_config."""

    update_branch: str = "main"  # written to config.json / config.bat / config.sh
@app.get('/')
def read_root():
    """Serve the UI's index.html with caching disabled."""
    index_html = os.path.join(SD_UI_DIR, 'index.html')
    return FileResponse(index_html, headers=NOCACHE_HEADERS)
def preload_model(file_path=None):
    """Load the model checkpoint at *file_path* into the runtime.

    When *file_path* is None the path comes from get_initial_model_to_load(). Nothing
    happens if that path is already the loaded model. On success the server goes
    Online. On failure the traceback is printed, the exception is stored in
    ``current_state_error`` and the server goes Unavailable (nothing is raised).
    """
    global current_state, current_state_error, current_model_path
    if file_path is None:
        file_path = get_initial_model_to_load()
    if file_path == current_model_path:
        return
    current_state = ServerStates.LoadingModel
    try:
        from sd_internal import runtime
        runtime.load_model_ckpt(ckpt_to_use=file_path)
    except Exception as e:
        current_model_path = None
        current_state_error = e
        current_state = ServerStates.Unavailable
        print(traceback.format_exc())
    else:
        current_model_path = file_path
        current_state_error = None
        current_state = ServerStates.Online
def _consume_render_results(runtime, task):
    """Run ``runtime.mk_img`` for *task* and copy every result it yields onto *task*.

    For each result:
    - The raw result is pushed onto ``task.buffer_queue`` when progress streaming is on.
    - String results are parsed as JSON, and the result is stored as ``task.response``.
    - Images listed under 'output' are copied into ``task.temp_images``.
    Stop and shutdown requests are forwarded to the runtime via ``runtime.stop_processing``.
    """
    global current_state_error
    data_queue = task.buffer_queue if task.request.stream_progress_updates else None
    for result in runtime.mk_img(task.request):
        stop_requested = (
            isinstance(current_state_error, (SystemExit, StopAsyncIteration))
            or isinstance(task.error, StopAsyncIteration)
        )
        if stop_requested:
            runtime.stop_processing = True
            if isinstance(current_state_error, StopAsyncIteration):
                # A global stop targets the running task: record it there and clear it.
                task.error = current_state_error
                current_state_error = None
                print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
        if data_queue is not None:
            data_queue.put(result)
        if isinstance(result, str):
            result = json.loads(result)
        task.response = result
        if 'output' in result:
            for index, out_obj in enumerate(result['output']):
                if 'path' in out_obj:
                    # Assumes paths look like '/image/tmp/<key>' ending in the image index;
                    # [11:] drops that 11-character prefix to get the runtime.temp_images key.
                    img_path = out_obj['path']
                    img_id = img_path[img_path.rindex('/') + 1:]
                    task.temp_images[int(img_id)] = runtime.temp_images[img_path[11:]]
                elif 'data' in out_obj:
                    # Use the list position; list.index() would return the first *equal*
                    # entry and put duplicate outputs into the wrong slot.
                    task.temp_images[index] = out_obj['data']
        task_cache.keep(task.request.session_id, TASK_TTL)


def thread_render():
    """Main loop of the render thread: preload the model, then render queued tasks.

    Takes one RenderTask at a time from ``tasks_queue`` and renders it. The task's
    lock is held while rendering and is always released afterwards, even if rendering
    raises. Exits once ``current_state_error`` holds a SystemExit (set by the shutdown
    handler).
    """
    global current_state, current_state_error
    from sd_internal import runtime
    current_state = ServerStates.Online
    preload_model()
    while True:
        task_cache.clean()
        if isinstance(current_state_error, SystemExit):
            current_state = ServerStates.Unavailable
            return
        try:
            task = tasks_queue.get(timeout=1)
        except queue.Empty:
            continue  # nothing queued; loop back to the shutdown check
        #if current_model_path != task.request.use_stable_diffusion_model:
        #    preload_model(task.request.use_stable_diffusion_model)
        if isinstance(current_state_error, StopAsyncIteration):
            # A stop request that arrived while no task was rendering is stale. Left in
            # place, it would fail this task and every later one.
            current_state_error = None
        if current_state_error:
            task.error = current_state_error
            tasks_queue.task_done()
            continue
        print(f'Session {task.request.session_id} starting task {id(task)}')
        current_state = ServerStates.Rendering
        task.lock.acquire(blocking=False)
        try:
            _consume_render_results(runtime, task)
        except Exception as e:
            task.error = e
            print(traceback.format_exc())
        finally:
            # Always unlock: a task left locked is reported as 'running' forever.
            task.lock.release()
            tasks_queue.task_done()
        task_cache.keep(task.request.session_id, TASK_TTL)
        if isinstance(task.error, StopAsyncIteration):
            print(f'Session {task.request.session_id} task {id(task)} cancelled!')
        elif task.error is not None:
            print(f'Session {task.request.session_id} task {id(task)} failed!')
        else:
            print(f'Session {task.request.session_id} task {id(task)} completed.')
        current_state = ServerStates.Online
# Start the rendering thread as a daemon so it does not keep the process alive on exit.
render_thread = threading.Thread(target=thread_render, daemon=True)
render_thread.start()


@app.on_event("shutdown")
def shutdown_event():
    """On application shutdown, signal the render thread to stop.

    thread_render returns once it sees a SystemExit in current_state_error.
    """
    global current_state_error
    shutdown_signal = SystemExit('Application shutting down.')
    current_state_error = shutdown_signal
# needs to support the legacy installations
def get_initial_model_to_load():
    """Return the checkpoint path (without '.ckpt') to load at startup.

    The default is SD_DIR/custom-model when SD_DIR/custom-model.ckpt exists, and
    SD_DIR/sd-v1-4 otherwise. A model configured in config.json (model ->
    stable-diffusion) takes precedence if its .ckpt file exists; if not, a message is
    printed and the default is used.
    """
    has_legacy_custom = os.path.exists(os.path.join(SD_DIR, 'custom-model.ckpt'))
    default_ckpt = os.path.join(SD_DIR, 'custom-model' if has_legacy_custom else 'sd-v1-4')

    config = getConfig()
    if 'model' not in config or 'stable-diffusion' not in config['model']:
        return default_ckpt

    configured_path = resolve_model_to_use(config['model']['stable-diffusion'])
    if os.path.exists(configured_path + '.ckpt'):
        return configured_path
    print('Could not find the configured custom model at:', configured_path + '.ckpt', '. Using the default one:', default_ckpt + '.ckpt')
    return default_ckpt
def resolve_model_to_use(model_name):
    """Map a model name to its checkpoint path (without the '.ckpt' extension).

    Models normally live in MODELS_DIR/stable-diffusion. For the built-in names
    'sd-v1-4' and 'custom-model', the legacy location SD_DIR/<name> is used when
    only the legacy .ckpt file exists.
    """
    new_location = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
    if model_name not in ('sd-v1-4', 'custom-model'):
        return new_location
    legacy_location = os.path.join(SD_DIR, model_name)
    if os.path.exists(new_location + '.ckpt') or not os.path.exists(legacy_location + '.ckpt'):
        return new_location
    return legacy_location
def save_model_to_config(model_name):
    """Record *model_name* as the selected stable-diffusion model in config.json."""
    config = getConfig()
    model_section = config.setdefault('model', {})
    model_section['stable-diffusion'] = model_name
    setConfig(config)
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
    """Return the server state and, if *session_id* is given, that session's task state.

    Raises HTTPException(500) when the render thread is dead or the server holds an
    error. A pending stop request (StopAsyncIteration) is transient and is not
    reported as a failure.
    """
    if not render_thread.is_alive(): # Render thread is dead.
        detail = str(current_state_error) if current_state_error else 'Render thread is dead.'
        raise HTTPException(status_code=500, detail=detail)
    if current_state_error and not isinstance(current_state_error, StopAsyncIteration):
        raise HTTPException(status_code=500, detail=str(current_state_error))
    # Alive
    response = {'status': str(current_state)}
    if session_id:
        task = task_cache.tryGet(session_id)
        if task:
            response['task'] = id(task)
            if task.lock.locked():
                response['session'] = 'running'
            elif isinstance(task.error, StopAsyncIteration):
                response['session'] = 'stopped'
            elif task.error:
                response['session'] = 'error'
            elif not task.buffer_queue.empty():
                response['session'] = 'buffer'
            elif task.response:
                response['session'] = 'completed'
            else:
                response['session'] = 'pending'
    return JSONResponse(response, headers=NOCACHE_HEADERS)
@app.post('/render')
def render(req : ImageRequest):
    """Queue a render task for ``req.session_id`` and return where to stream its progress.

    Raises HTTPException(500) if the render thread has died, and HTTPException(503) if
    the session already has a queued task that has not started yet.
    """
    if not render_thread.is_alive(): # Render thread is dead
        raise HTTPException(status_code=500, detail='Rendering thread has died.') # HTTP500 Internal Server Error
    # Alive, check if task in cache
    task = task_cache.tryGet(req.session_id)
    if task and not task.response and not task.error and not task.lock.locked(): # Unstarted task pending, deny queueing more than one.
        raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable

    r = Request()
    r.session_id = req.session_id
    r.prompt = req.prompt
    r.negative_prompt = req.negative_prompt
    r.init_image = req.init_image
    r.mask = req.mask
    r.num_outputs = req.num_outputs
    r.num_inference_steps = req.num_inference_steps
    r.guidance_scale = req.guidance_scale
    r.width = req.width
    r.height = req.height
    r.seed = req.seed
    r.prompt_strength = req.prompt_strength
    r.sampler = req.sampler
    # r.allow_nsfw = req.allow_nsfw
    r.turbo = req.turbo
    r.use_cpu = req.use_cpu
    r.use_full_precision = req.use_full_precision
    r.save_to_disk_path = req.save_to_disk_path
    r.use_upscale = req.use_upscale
    r.use_face_correction = req.use_face_correction
    r.show_only_filtered_image = req.show_only_filtered_image
    r.output_format = req.output_format
    r.stream_progress_updates = True # the underlying implementation only supports streaming
    # Image previews are only streamed when the client asked for progress updates at all.
    r.stream_image_progress = req.stream_image_progress if req.stream_progress_updates else False
    r.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
    save_model_to_config(req.use_stable_diffusion_model)

    new_task = RenderTask(r)
    task_cache.put(r.session_id, new_task, TASK_TTL)
    tasks_queue.put(new_task)

    response = {
        'status': str(current_state),
        'queue': tasks_queue.qsize(),
        'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
        'task': id(new_task)
    }
    return JSONResponse(response, headers=NOCACHE_HEADERS)
async def read_data_generator(data:queue.Queue, lock:threading.Lock):
    """Yield every item currently buffered in *data*, then stop.

    Only drains what is already queued; it does not wait for the producer. If the
    queue turns out to be empty (for example, drained by a concurrent reader), the
    generator simply ends. It no longer yields None, which a streaming response
    cannot encode. *lock* is accepted for interface compatibility but is not used.
    """
    while True:
        try:
            res = data.get(block=False)
        except queue.Empty:
            return
        data.task_done()
        yield res
@app.get('/image/stream/{session_id:str}/{task_id:int}')
def stream(session_id:str, task_id:int):
    """Stream *session_id*'s buffered results, or return the final response once done.

    Raises HTTPException with:
    - 410 if the session has no task,
    - 409 if *task_id* is not that session's current task,
    - 425 if the task is idle and has not produced a response yet.
    """
    #TODO Move to WebSockets ??
    task = task_cache.tryGet(session_id)
    if not task:
        raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
    if id(task) != task_id:
        raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
    if task.buffer_queue.empty() and not task.lock.locked():
        if task.response:
            return JSONResponse(task.response, headers=NOCACHE_HEADERS)
        raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
    return StreamingResponse(read_data_generator(task.buffer_queue, task.lock), media_type='application/json')
@app.get('/image/stop')
def stop(session_id:str=None):
    """Request cancellation of a task.

    Without *session_id*, signals the render thread to stop whatever it is rendering.
    With it, marks that session's task as stopped.

    Raises HTTPException with:
    - 409 if nothing is running (global stop) or the task is already stopped,
    - 404 if the session has no task.
    """
    global current_state_error
    if not session_id:
        if current_state in (ServerStates.Online, ServerStates.Unavailable):
            raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
        current_state_error = StopAsyncIteration()
        return {'OK'}
    task = task_cache.tryGet(session_id)
    if not task:
        raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
    if isinstance(task.error, StopAsyncIteration):
        raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
    task.error = StopAsyncIteration('')
    return {'OK'}
@app.get('/image/tmp/{session_id}/{img_id:int}')
def get_image(session_id, img_id):
    """Return image slot *img_id* of *session_id*'s task.

    String slots are returned as-is. Other slots are rewound and streamed as
    image/jpeg.

    Raises HTTPException with:
    - 410 if the session has no task,
    - 404 if *img_id* is outside the task's output slots,
    - 425 if the slot has not been filled yet.
    """
    task = task_cache.tryGet(session_id)
    if not task:
        raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
    if not 0 <= img_id < len(task.temp_images):
        raise HTTPException(status_code=404, detail=f'Session {session_id} task has no image {img_id}.') # HTTP404 Not Found
    img_data = task.temp_images[img_id]
    if not img_data:
        raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
    if isinstance(img_data, str):
        return img_data
    img_data.seek(0)
    return StreamingResponse(img_data, media_type='image/jpeg')
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
    """Persist the chosen update branch.

    Stores ``update_branch`` in CONFIG_DIR/config.json, keeping any other keys already
    there (such as the 'model' entry written by save_model_to_config). Also writes
    config.bat and config.sh with ``set`` / ``export`` lines.

    Raises HTTPException(400) for a branch name containing characters other than
    letters, digits, '.', '_', '/' and '-', and HTTPException(500) if writing fails.
    """
    import re
    # The branch name is written verbatim into .bat/.sh files (presumably executed by
    # the launcher scripts), so reject anything that could break out of the assignment.
    if not re.fullmatch(r'[A-Za-z0-9._/-]+', req.update_branch):
        raise HTTPException(status_code=400, detail=f'Invalid update branch name: {req.update_branch!r}') # HTTP400 Bad Request
    try:
        config = getConfig()
        # Work on a copy so a default object returned by getConfig() is never mutated.
        config = dict(config) if isinstance(config, dict) else {}
        config['update_branch'] = req.update_branch
        file_contents = {
            'config.json': json.dumps(config),
            'config.bat': f'@set update_branch={req.update_branch}',
            'config.sh': f'export update_branch={req.update_branch}',
        }
        for file_name, content in file_contents.items():
            with open(os.path.join(CONFIG_DIR, file_name), 'w') as f:
                f.write(content)
        return {'OK'}
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
# Sentinel meaning "the caller passed no default_val", so getConfig() can hand out
# a fresh dict each time instead of one shared (and mutable) default object.
_CONFIG_NO_DEFAULT = object()


def getConfig(default_val=_CONFIG_NO_DEFAULT):
    """Return the parsed contents of CONFIG_DIR/config.json.

    Args:
        default_val: returned when the file is missing or cannot be read or parsed.
            When omitted, a new empty dict is returned, so callers may mutate the
            result safely. Pass ``None`` to detect a missing or unreadable config.
    """
    if default_val is _CONFIG_NO_DEFAULT:
        default_val = {}
    try:
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        if not os.path.exists(config_json_path):
            return default_val
        with open(config_json_path, 'r') as f:
            return json.load(f)
    except Exception as e:
        print(str(e))
        print(traceback.format_exc())
        return default_val
def setConfig(config):
    """Write *config* as JSON to CONFIG_DIR/config.json.

    Best-effort: on failure the error and its traceback are printed and nothing is
    raised. Always returns None.
    """
    try:
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        with open(config_json_path, 'w') as f:
            json.dump(config, f)
    except Exception as e:  # previously a bare except that referenced an unbound `e`
        print(str(e))
        print(traceback.format_exc())
def getModels():
    """Describe the available stable-diffusion checkpoints for the UI.

    Returns ``{'active': {'stable-diffusion': name}, 'options': {'stable-diffusion': names}}``.

    Options, each listed once:
    - 'sd-v1-4';
    - every ``*.ckpt`` file in MODELS_DIR/stable-diffusion, sorted, without extension;
    - 'custom-model' when the legacy SD_DIR/custom-model.ckpt exists.

    The active model is 'sd-v1-4', or 'custom-model' when that legacy file exists.
    Either is overridden by config.json's model -> stable-diffusion entry.
    """
    active = 'sd-v1-4'
    options = ['sd-v1-4']

    def add_option(name):
        # Avoid duplicates, e.g. an 'sd-v1-4.ckpt' placed in the models directory.
        if name not in options:
            options.append(name)

    # custom models; the directory may not exist, in which case there are none
    sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
    if os.path.isdir(sd_models_dir):
        for file in sorted(os.listdir(sd_models_dir)):
            if file.endswith('.ckpt'):
                add_option(os.path.splitext(file)[0])

    # legacy
    custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
    if os.path.exists(custom_weight_path):
        active = 'custom-model'
        add_option('custom-model')

    config = getConfig()
    if 'model' in config and 'stable-diffusion' in config['model']:
        active = config['model']['stable-diffusion']

    return {
        'active': {'stable-diffusion': active},
        'options': {'stable-diffusion': options},
    }
@app.get('/get')
def read_web_data(key:str=None):
    """Return one piece of UI data selected by *key*.

    Supported keys: 'app_config', 'models', 'modifiers', 'output_dir'.

    Raises HTTPException with:
    - 418 when no key is given,
    - 500 when the config file is missing or unreadable,
    - 404 for an unknown key.
    """
    if key is None: # /get without parameters, stable-diffusion easter egg.
        raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
    elif key == 'app_config':
        config = getConfig(default_val=None)
        if config is None:
            raise HTTPException(status_code=500, detail="Config file is missing or unreadable")
        return config
    elif key == 'models':
        return getModels()
    elif key == 'modifiers':
        return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
    elif key == 'output_dir':
        return {outpath}
    raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
# don't log certain requests
class LogSuppressFilter(logging.Filter):
    """Drop log records whose message contains any ACCESS_LOG_SUPPRESS_PATH_PREFIXES entry."""

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        # Substring match anywhere in the formatted message, not just a prefix match.
        suppressed = any(fragment in message for fragment in ACCESS_LOG_SUPPRESS_PATH_PREFIXES)
        return not suppressed


logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# Start the browser UI: open the local server's address in the default web browser.
import webbrowser
webbrowser.open('http://localhost:9000')