easydiffusion/ui/server.py

235 lines
6.6 KiB
Python
Raw Normal View History

2022-09-09 17:04:32 +02:00
import json
import os
import sys
import traceback
from typing import Optional
SCRIPT_DIR = os.getcwd()
print('started in ', SCRIPT_DIR)

# Root directory of the UI; all frontend/config paths below are built from it.
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
if SD_UI_DIR is None:
    # Fail fast with an actionable message; otherwise os.path.dirname(None)
    # below raises an opaque TypeError during import.
    raise RuntimeError('The SD_UI_PATH environment variable must be set to the UI directory')
# Put the UI directory's parent on sys.path before the later
# `from sd_internal import ...` — presumably sd_internal lives there; verify.
sys.path.append(os.path.dirname(SD_UI_DIR))

CONFIG_DIR = os.path.join(SD_UI_DIR, '..', 'scripts')
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
from fastapi import FastAPI, HTTPException
from starlette.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
2022-09-14 16:48:46 +02:00
# this is needed for development.
from fastapi.middleware.cors import CORSMiddleware
2022-09-14 18:57:03 +02:00
import logging
from sd_internal import Request, Response
app = FastAPI()

# The UI may be served by a separate local dev server (port 9001) while its
# requests still have to reach this Python server (port 9000), so
# cross-origin requests are accepted from any origin.
# NOTE(review): a wildcard origin together with allow_credentials=True may let
# any site the user visits make credentialed requests to this local server;
# consider narrowing `origins` outside of development.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Model-loading state: written by /ping, read by /image/stop.
model_loaded = model_is_loading = False
modifiers_cache = None

# Directory reported by /output_dir: a folder inside the user's home directory.
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
    """JSON body of ``POST /image``; its fields are copied onto an ``sd_internal.Request``.

    Fields that default to ``None`` are annotated ``Optional[...]`` so that a
    client sending an explicit JSON ``null`` validates the same way as one
    that omits the field.
    """
    prompt: str = ""
    init_image: Optional[str] = None # base64
    mask: Optional[str] = None # base64
    num_outputs: int = 1
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    width: int = 512
    height: int = 512
    seed: int = 42
    prompt_strength: float = 0.8
    # allow_nsfw: bool = False
    save_to_disk_path: Optional[str] = None
    turbo: bool = True
    use_cpu: bool = False
    use_full_precision: bool = False
    use_face_correction: Optional[str] = None # or "GFPGANv1.3"
    use_upscale: Optional[str] = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
    show_only_filtered_image: bool = False
    stream_progress_updates: bool = False
class SetAppConfigRequest(BaseModel):
    """JSON body of ``POST /app_config``."""

    # Branch name that /app_config writes into config.json/.bat/.sh.
    update_branch: str = "main"
@app.get('/')
def read_root():
    """Serve the frontend's index.html with client caching disabled."""
    index_path = os.path.join(SD_UI_DIR,'frontend/dist/index.html')
    no_cache_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }
    return FileResponse(index_path, headers=no_cache_headers)
# then get the js files
@app.get('/index.js')
def read_scripts():
    """Serve the bundled frontend JavaScript (frontend/dist/index.js)."""
    script_path = os.path.join(SD_UI_DIR, 'frontend/dist/index.js')
    return FileResponse(script_path)
# then get the css files
@app.get('/index.css')
def read_styles():
    """Serve the bundled frontend stylesheet (frontend/dist/index.css)."""
    stylesheet_path = os.path.join(SD_UI_DIR, 'frontend/dist/index.css')
    return FileResponse(stylesheet_path)
@app.get('/ping')
async def ping():
    """Health check that lazily loads the model on the first call.

    Returns ``{'OK'}`` once the model is loaded and ``{'ERROR'}`` while a load
    is already in progress (FastAPI encodes these sets as JSON lists).

    Raises:
        HTTPException: 500 if loading the model fails. The loading flag is
            cleared in that case, so a later ping retries the load.
    """
    global model_loaded, model_is_loading
    if model_loaded:
        return {'OK'}
    if model_is_loading:
        return {'ERROR'}

    model_is_loading = True
    try:
        from sd_internal import runtime
        # NOTE(review): this load presumably blocks; inside an `async def` it
        # stalls the event loop (and every other request) until it finishes.
        runtime.load_model_ckpt(ckpt_to_use="sd-v1-4")
        model_loaded = True
    except Exception as e:
        print(traceback.format_exc())
        # The exception must be raised: returning it would be serialised
        # as an ordinary 200 response.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always clear the flag; if it stayed set after a failed load, every
        # later ping would answer ERROR and never retry.
        model_is_loading = False
    return {'OK'}
@app.post('/image')
def image(req : ImageRequest):
    """Generate images for ``req`` via ``sd_internal.runtime.mk_img``.

    The request fields are copied onto an ``sd_internal.Request``. If
    ``stream_progress_updates`` is set, the object returned by ``mk_img`` is
    streamed back as ``application/json``; otherwise its ``.json()`` is
    returned.

    Raises:
        HTTPException: 500 if ``mk_img`` (or ``.json()``) raises while this
            handler runs. Errors raised later, while a stream is being
            consumed, are not caught here.
    """
    from sd_internal import runtime

    r = Request()
    r.prompt = req.prompt
    r.init_image = req.init_image
    r.mask = req.mask
    r.num_outputs = req.num_outputs
    r.num_inference_steps = req.num_inference_steps
    r.guidance_scale = req.guidance_scale
    r.width = req.width
    r.height = req.height
    r.seed = req.seed
    r.prompt_strength = req.prompt_strength
    # r.allow_nsfw = req.allow_nsfw
    r.turbo = req.turbo
    r.use_cpu = req.use_cpu
    r.use_full_precision = req.use_full_precision
    r.save_to_disk_path = req.save_to_disk_path
    r.use_upscale = req.use_upscale
    r.use_face_correction = req.use_face_correction
    r.show_only_filtered_image = req.show_only_filtered_image
    r.stream_progress_updates = req.stream_progress_updates

    try:
        res = runtime.mk_img(r)
        if r.stream_progress_updates:
            return StreamingResponse(res, media_type='application/json')
        return res.json()
    except Exception as e:
        print(traceback.format_exc())
        # Raise (not return) so FastAPI sends an actual 500 response.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get('/image/stop')
def stop():
    """Ask the runtime to stop the current generation.

    Returns ``{'ERROR'}`` while the model is still loading; otherwise sets
    ``runtime.stop_processing = True`` and returns ``{'OK'}``.

    Raises:
        HTTPException: 500 if the runtime cannot be reached.
    """
    if model_is_loading:
        return {'ERROR'}
    try:
        from sd_internal import runtime
        runtime.stop_processing = True
        return {'OK'}
    except Exception as e:
        print(traceback.format_exc())
        # Raise (not return) so FastAPI sends an actual 500 response.
        raise HTTPException(status_code=500, detail=str(e)) from e
def _is_safe_branch_name(name):
    """Return True if ``name`` can be embedded unquoted in config.bat/config.sh.

    The branch is written verbatim into ``@set ...`` / ``export ...`` lines.
    If those files are executed or sourced (presumably by the launcher
    scripts — verify), characters such as ``;``, ``&``, ``|``, ``$``, ``%``,
    quotes, spaces or newlines would allow command injection. Only a
    conservative subset of git branch-name characters is therefore accepted,
    and a leading ``-`` is rejected.
    """
    import re
    return re.fullmatch(r'[A-Za-z0-9][A-Za-z0-9._/-]*', name) is not None


@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
    """Persist ``update_branch`` to config.json, config.bat and config.sh in CONFIG_DIR.

    Returns ``{'OK'}`` on success.

    Raises:
        HTTPException: 400 if the branch name contains disallowed characters;
            500 if writing any of the files fails.
    """
    branch = req.update_branch
    # Validate before the try block so this 400 is not re-wrapped as a 500.
    if not _is_safe_branch_name(branch):
        raise HTTPException(status_code=400, detail=f'Invalid branch name: {branch!r}')

    config = {
        'update_branch': branch
    }
    # Insertion order preserves the original write order: json, bat, sh.
    contents_by_filename = {
        'config.json': json.dumps(config),
        'config.bat': f'@set update_branch={branch}',
        'config.sh': f'export update_branch={branch}',
    }
    try:
        for filename, contents in contents_by_filename.items():
            with open(os.path.join(CONFIG_DIR, filename), 'w') as f:
                f.write(contents)
        return {'OK'}
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get('/app_config')
def getAppConfig():
    """Return the parsed contents of CONFIG_DIR/config.json.

    Raises:
        HTTPException: 500 with detail "No config file" if config.json does
            not exist; 500 with the error message if it cannot be read or
            parsed.
    """
    config_json_path = os.path.join(CONFIG_DIR, 'config.json')
    try:
        with open(config_json_path, 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        # Handled separately: a plain `raise` inside the generic handler
        # below would be caught and re-wrapped with a different detail.
        raise HTTPException(status_code=500, detail="No config file") from None
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
# moved these to the root for easier pathing
# TODO: change the vite config for public files
@app.get('/ding.mp3')
def read_ding():
    """Serve frontend/dist/ding.mp3."""
    return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/ding.mp3'))


@app.get('/kofi.png')
def read_kofi():
    """Serve frontend/dist/kofi.png.

    Named distinctly from ``read_modifiers``: sharing that name made the
    /modifiers.json handler shadow this function at module level and
    duplicated the route name.
    """
    return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/kofi.png'))


@app.get('/modifiers.json')
def read_modifiers():
    """Serve frontend/dist/modifiers.json."""
    return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/modifiers.json'))
@app.get('/output_dir')
def read_home_dir():
    """Return the output directory path wrapped in a one-element set.

    FastAPI encodes the set as a JSON list, i.e. ``["<path>"]``.
    """
    response = {outpath}
    return response
# don't log /ping requests
class HealthCheckLogFilter(logging.Filter):
    """Logging filter that drops records whose message contains ``/ping``."""

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        return '/ping' not in message


# Attach the filter to uvicorn's access logger so health checks are not logged.
logging.getLogger('uvicorn.access').addFilter(HealthCheckLogFilter())
# start the browser ui
# NOTE: this runs at import time, i.e. whenever this module is loaded.
import webbrowser

webbrowser.open('http://localhost:9000')