diff --git a/scripts/check_modules.py b/scripts/check_modules.py index bad46943..301b6163 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -25,6 +25,8 @@ modules_to_check = { "fastapi": "0.85.1", "pycloudflared": "0.2.0", "ruamel.yaml": "0.17.21", + "sqlalchemy": "2.0.19", + "python-multipart": "0.0.6", # "xformers": "0.0.16", } modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"] diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index e181f9b8..6f8d731e 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -38,6 +38,7 @@ SD_UI_DIR = os.getenv("SD_UI_PATH", None) CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts")) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models")) +BUCKET_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "bucket")) USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins")) CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins")) diff --git a/ui/easydiffusion/bucket_manager.py b/ui/easydiffusion/bucket_manager.py new file mode 100644 index 00000000..f587bc7f --- /dev/null +++ b/ui/easydiffusion/bucket_manager.py @@ -0,0 +1,103 @@ +from typing import List + +from fastapi import Depends, FastAPI, HTTPException, Response, File +from sqlalchemy.orm import Session + +from easydiffusion.easydb import crud, models, schemas +from easydiffusion.easydb.database import SessionLocal, engine + +from requests.compat import urlparse + +import base64, json + +MIME_TYPES = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "gif": "image/gif", + "png": "image/png", + "webp": "image/webp", + "js": "text/javascript", + "htm": "text/html", + "html": "text/html", + "css": "text/css", + "json": "application/json", + "mjs": "application/json", + "yaml": "application/yaml", + "svg": "image/svg+xml", + "txt": "text/plain", +} + +def init(): + from easydiffusion.server import server_api + + models.BucketBase.metadata.create_all(bind=engine) + + + # Dependency + def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + @server_api.get("/bucket/{obj_path:path}") + def bucket_get_object(obj_path: str, db: Session = Depends(get_db)): + filename = get_filename_from_url(obj_path) + path = get_path_from_url(obj_path) + + if filename==None: + bucket = crud.get_bucket_by_path(db, path=path) + if bucket == None: + raise HTTPException(status_code=404, detail="Bucket not found") + bucketfiles = db.query(models.BucketFile).with_entities(models.BucketFile.filename).filter(models.BucketFile.bucket_id == bucket.id).all() + bucketfiles = [ x.filename for x in bucketfiles ] + return bucketfiles + + else: + bucket_id = crud.get_bucket_by_path(db, path).id + bucketfile = db.query(models.BucketFile).filter(models.BucketFile.bucket_id == bucket_id, models.BucketFile.filename == filename).first() + + suffix = get_suffix_from_filename(filename) + + return Response(content=bucketfile.data, media_type=MIME_TYPES.get(suffix, "application/octet-stream")) + + @server_api.post("/bucket/{obj_path:path}") + def bucket_post_object(obj_path: str, file: bytes = File(), db: Session = Depends(get_db)): + filename = get_filename_from_url(obj_path) + path = get_path_from_url(obj_path) + bucket = crud.get_bucket_by_path(db, path) + + if bucket == None: + bucket_id = crud.create_bucket(db=db, bucket=schemas.BucketCreate(path=path)) + else: + bucket_id = bucket.id + + bucketfile = schemas.BucketFileCreate(filename=filename, data=file) + result = crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id) + result.data = base64.encodestring(result.data) + return result + + + @server_api.post("/buckets/{bucket_id}/items/", response_model=schemas.BucketFile) + def create_bucketfile_in_bucket( + bucket_id: int, bucketfile: schemas.BucketFileCreate, db: Session = Depends(get_db) + ): + bucketfile.data = base64.decodestring(bucketfile.data) + result = crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id) + result.data = base64.encodestring(result.data) + return result + + +def get_filename_from_url(url): + path = urlparse(url).path + name = path[path.rfind('/')+1:] + return name or None + +def get_path_from_url(url): + path = urlparse(url).path + path = path[0:path.rfind('/')] + return path or None + +def get_suffix_from_filename(filename): + return filename[filename.rfind('.')+1:] diff --git a/ui/main.py b/ui/main.py index f5998622..a7568ba6 100644 --- a/ui/main.py +++ b/ui/main.py @@ -1,4 +1,4 @@ -from easydiffusion import model_manager, app, server +from easydiffusion import model_manager, app, server, bucket_manager from easydiffusion.server import server_api # required for uvicorn server.init() @@ -7,6 +7,7 @@ server.init() model_manager.init() app.init() app.init_render_threads() +bucket_manager.init() # start the browser ui app.open_browser() diff --git a/ui/media/images/noimg.png b/ui/media/images/noimg.png new file mode 100644 index 00000000..33d7332f Binary files /dev/null and b/ui/media/images/noimg.png differ diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 1c76543a..d8c1b614 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -529,6 +529,7 @@ function showImages(reqBody, res, outputContainer, livePreview) { { text: "Upscale", on_click: onUpscaleClick }, { text: "Fix Faces", on_click: onFixFacesClick }, ], + { text: "Use as Thumbnail", on_click: onUseAsThumbnailClick }, ] // include the plugins @@ -669,6 +670,20 @@ function onMakeSimilarClick(req, img) { createTask(newTaskRequest) } +function onUseAsThumbnailClick(req, img) { + console.log(req) + console.log(img) + let embedding = prompt("Embedding name") + fetch(img.src) + .then(response => response.blob()) + .then(async function(blob) { + const formData = new FormData() + formData.append("file", blob) + const response = await fetch(`bucket/embeddings/${embedding}.jpg`, { method: 'POST', body: formData }); + console.log(response) + }) +} + function enqueueImageVariationTask(req, img, reqDiff) { const imageSeed = img.getAttribute("data-seed") @@ -2356,19 +2371,27 @@ document.getElementById("toggle-tensorrt-install").addEventListener("click", fun /* Embeddings */ +let icl = [] function updateEmbeddingsList(filter = "") { - function html(model, prefix = "", filter = "") { + function html(model, iconlist = [], prefix = "", filter = "") { filter = filter.toLowerCase() let toplevel = "" let folders = "" + console.log(iconlist) + let embIcon = Object.assign({}, ...iconlist.map( x=> ({[x.toLowerCase().split('.').slice(0,-1).join('.')]:x}))) model?.forEach((m) => { if (typeof m == "string") { - if (m.toLowerCase().search(filter) != -1) { - toplevel += ` ` + let token=m.toLowerCase() + if (token.search(filter) != -1) { + let img = '/media/images/noimg.png' + if (token in embIcon) { + img = `/bucket/embeddings/${embIcon[token]}` + } + toplevel += ` ` } } else { - let subdir = html(m[1], prefix + m[0] + "/", filter) + let subdir = html(m[1], iconlist, prefix + m[0] + "/", filter) if (subdir != "") { folders += `