diff --git a/scripts/check_modules.py b/scripts/check_modules.py
index a5edf5f5..aecf7576 100644
--- a/scripts/check_modules.py
+++ b/scripts/check_modules.py
@@ -25,6 +25,8 @@ modules_to_check = {
"fastapi": "0.85.1",
"pycloudflared": "0.2.0",
"ruamel.yaml": "0.17.21",
+ "sqlalchemy": "2.0.19",
+ "python-multipart": "0.0.6",
# "xformers": "0.0.16",
}
modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"]
diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py
index aa1c5ba7..e2c190f8 100644
--- a/ui/easydiffusion/app.py
+++ b/ui/easydiffusion/app.py
@@ -38,6 +38,7 @@ SD_UI_DIR = os.getenv("SD_UI_PATH", None)
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))
+BUCKET_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "bucket"))
USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins"))
CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins"))
diff --git a/ui/easydiffusion/bucket_manager.py b/ui/easydiffusion/bucket_manager.py
new file mode 100644
index 00000000..4400fd17
--- /dev/null
+++ b/ui/easydiffusion/bucket_manager.py
@@ -0,0 +1,127 @@
+from typing import List
+
+from fastapi import Depends, FastAPI, HTTPException, Response, File
+from fastapi.responses import FileResponse
+from sqlalchemy.orm import Session
+
+from easydiffusion.easydb import crud, models, schemas
+from easydiffusion.easydb.database import SessionLocal, engine
+
+from requests.compat import urlparse
+from os.path import abspath
+
+import base64, json
+
+MIME_TYPES = {
+ "jpg": "image/jpeg",
+ "jpeg": "image/jpeg",
+ "gif": "image/gif",
+ "png": "image/png",
+ "webp": "image/webp",
+ "js": "text/javascript",
+ "htm": "text/html",
+ "html": "text/html",
+ "css": "text/css",
+ "json": "application/json",
+ "mjs": "application/json",
+ "yaml": "application/yaml",
+ "svg": "image/svg+xml",
+ "txt": "text/plain",
+}
+
+def init():
+ from easydiffusion.server import server_api
+
+ models.BucketBase.metadata.create_all(bind=engine)
+
+
+ # Dependency
+ def get_db():
+ db = SessionLocal()
+ try:
+ yield db
+ finally:
+ db.close()
+
+ @server_api.get("/bucket/{obj_path:path}")
+ def bucket_get_object(obj_path: str, db: Session = Depends(get_db)):
+ filename = get_filename_from_url(obj_path)
+ path = get_path_from_url(obj_path)
+
+ if filename==None:
+ bucket = crud.get_bucket_by_path(db, path=path)
+ if bucket == None:
+ raise HTTPException(status_code=404, detail="Bucket not found")
+ bucketfiles = db.query(models.BucketFile).with_entities(models.BucketFile.filename).filter(models.BucketFile.bucket_id == bucket.id).all()
+ bucketfiles = [ x.filename for x in bucketfiles ]
+ return bucketfiles
+
+ else:
+ bucket_id = crud.get_bucket_by_path(db, path).id
+ bucketfile = db.query(models.BucketFile).filter(models.BucketFile.bucket_id == bucket_id, models.BucketFile.filename == filename).first()
+
+ suffix = get_suffix_from_filename(filename)
+
+ return Response(content=bucketfile.data, media_type=MIME_TYPES.get(suffix, "application/octet-stream"))
+
+ @server_api.post("/bucket/{obj_path:path}")
+ def bucket_post_object(obj_path: str, file: bytes = File(), db: Session = Depends(get_db)):
+ filename = get_filename_from_url(obj_path)
+ path = get_path_from_url(obj_path)
+ bucket = crud.get_bucket_by_path(db, path)
+
+ if bucket == None:
+ bucket_id = crud.create_bucket(db=db, bucket=schemas.BucketCreate(path=path))
+ else:
+ bucket_id = bucket.id
+
+ bucketfile = schemas.BucketFileCreate(filename=filename, data=file)
+ result = crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id)
+ result.data = base64.encodestring(result.data)
+ return result
+
+
+ @server_api.post("/buckets/{bucket_id}/items/", response_model=schemas.BucketFile)
+ def create_bucketfile_in_bucket(
+ bucket_id: int, bucketfile: schemas.BucketFileCreate, db: Session = Depends(get_db)
+ ):
+ bucketfile.data = base64.decodestring(bucketfile.data)
+ result = crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id)
+ result.data = base64.encodestring(result.data)
+ return result
+
+ @server_api.get("/image/{image_path:path}")
+ def get_image(image_path: str, db: Session = Depends(get_db)):
+ from easydiffusion.easydb.mappings import Image
+ image_path = str(abspath(image_path))
+ amount = len(db.query(Image).filter(Image.path == image_path).all())
+ if amount > 0:
+ image = db.query(Image).filter(Image.path == image_path).first()
+ return FileResponse(image.path)
+ else:
+ raise HTTPException(status_code=404, detail="Image not found")
+
+ @server_api.get("/all_images")
+ def get_all_images(db: Session = Depends(get_db)):
+ from easydiffusion.easydb.mappings import Image
+ images = db.query(Image).all()
+ sum_string = "
"
+ for img in images:
+ options = f"Path: {img.path}\nPrompt: {img.prompt}\nNegative Prompt: {img.negative_prompt}\nSeed: {img.seed}\nModel: {img.use_stable_diffusion_model}\nSize: {img.height}x{img.width}\nSampler: {img.sampler_name}\nSteps: {img.num_inference_steps}\nGuidance Scale: {img.guidance_scale}\nLoRA: {img.lora}\nUpscaling: {img.use_upscale}\nFace Correction: {img.use_face_correction}\n"
+ sum_string += f"

"
+ sum_string += "
"
+ return Response(content=sum_string, media_type="text/html")
+
+
+def get_filename_from_url(url):
+ path = urlparse(url).path
+ name = path[path.rfind('/')+1:]
+ return name or None
+
+def get_path_from_url(url):
+ path = urlparse(url).path
+ path = path[0:path.rfind('/')]
+ return path or None
+
+def get_suffix_from_filename(filename):
+ return filename[filename.rfind('.')+1:]
diff --git a/ui/easydiffusion/easydb/crud.py b/ui/easydiffusion/easydb/crud.py
new file mode 100644
index 00000000..7550a52a
--- /dev/null
+++ b/ui/easydiffusion/easydb/crud.py
@@ -0,0 +1,25 @@
+from sqlalchemy.orm import Session
+
+from easydiffusion.easydb import models, schemas
+
+
+def get_bucket_by_path(db: Session, path: str):
+ return db.query(models.Bucket).filter(models.Bucket.path == path).first()
+
+
+def create_bucket(db: Session, bucket: schemas.BucketCreate):
+ db_bucket = models.Bucket(path=bucket.path)
+ db.add(db_bucket)
+ db.commit()
+ db.refresh(db_bucket)
+ return db_bucket
+
+
+def create_bucketfile(db: Session, bucketfile: schemas.BucketFileCreate, bucket_id: int):
+ db_bucketfile = models.BucketFile(**bucketfile.dict(), bucket_id=bucket_id)
+ db.merge(db_bucketfile)
+ db.commit()
+ from pprint import pprint
+ db_bucketfile = db.query(models.BucketFile).filter(models.BucketFile.bucket_id==bucket_id, models.BucketFile.filename==bucketfile.filename).first()
+ return db_bucketfile
+
diff --git a/ui/easydiffusion/easydb/database.py b/ui/easydiffusion/easydb/database.py
new file mode 100644
index 00000000..e3c92845
--- /dev/null
+++ b/ui/easydiffusion/easydb/database.py
@@ -0,0 +1,15 @@
+import os
+from easydiffusion import app
+
+from sqlalchemy import create_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+
+os.makedirs(app.BUCKET_DIR, exist_ok=True)
+SQLALCHEMY_DATABASE_URL = "sqlite:///"+os.path.join(app.BUCKET_DIR, "bucket.db")
+print("## SQLALCHEMY_DATABASE_URL = ", SQLALCHEMY_DATABASE_URL)
+
+engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+BucketBase = declarative_base()
diff --git a/ui/easydiffusion/easydb/mappings.py b/ui/easydiffusion/easydb/mappings.py
new file mode 100644
index 00000000..ad68ecab
--- /dev/null
+++ b/ui/easydiffusion/easydb/mappings.py
@@ -0,0 +1,32 @@
+from sqlalchemy import Column, Integer, String, Float, Boolean
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+class Image(Base):
+ __tablename__ = 'images'
+
+ path = Column(String, primary_key=True)
+ seed = Column(Integer)
+ use_stable_diffusion_model = Column(String)
+ clip_skip = Column(Boolean)
+ use_vae_model = Column(String)
+ sampler_name = Column(String)
+ width = Column(Integer)
+ height = Column(Integer)
+ num_inference_steps = Column(Integer)
+ guidance_scale = Column(Float)
+ lora = Column(String)
+ use_hypernetwork_model = Column(String)
+ tiling = Column(String)
+ use_face_correction = Column(String)
+ use_upscale = Column(String)
+ prompt = Column(String)
+ negative_prompt = Column(String)
+
+ def __repr__(self):
+ return "" % (
+ self.path, self.seed, self.use_stable_diffusion_model, self.clip_skip, self.use_vae_model, self.sampler_name, self.width, self.height, self.num_inference_steps, self.guidance_scale, self.lora, self.use_hypernetwork_model, self.tiling, self.use_face_correction, self.use_upscale, self.prompt, self.negative_prompt)
+
+from easydiffusion.easydb.database import engine
+Image.metadata.create_all(engine)
\ No newline at end of file
diff --git a/ui/easydiffusion/easydb/models.py b/ui/easydiffusion/easydb/models.py
new file mode 100644
index 00000000..04834951
--- /dev/null
+++ b/ui/easydiffusion/easydb/models.py
@@ -0,0 +1,25 @@
+from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, BLOB
+from sqlalchemy.orm import relationship
+
+from easydiffusion.easydb.database import BucketBase
+
+
+class Bucket(BucketBase):
+ __tablename__ = "bucket"
+
+ id = Column(Integer, primary_key=True, index=True)
+ path = Column(String, unique=True, index=True)
+
+ bucketfiles = relationship("BucketFile", back_populates="bucket")
+
+
+class BucketFile(BucketBase):
+ __tablename__ = "bucketfile"
+
+ filename = Column(String, index=True, primary_key=True)
+ bucket_id = Column(Integer, ForeignKey("bucket.id"), primary_key=True)
+
+ data = Column(BLOB, index=False)
+
+ bucket = relationship("Bucket", back_populates="bucketfiles")
+
diff --git a/ui/easydiffusion/easydb/schemas.py b/ui/easydiffusion/easydb/schemas.py
new file mode 100644
index 00000000..68bc04e2
--- /dev/null
+++ b/ui/easydiffusion/easydb/schemas.py
@@ -0,0 +1,36 @@
+from typing import List, Union
+
+from pydantic import BaseModel
+
+
+class BucketFileBase(BaseModel):
+ filename: str
+ data: bytes
+
+
+class BucketFileCreate(BucketFileBase):
+ pass
+
+
+class BucketFile(BucketFileBase):
+ bucket_id: int
+
+ class Config:
+ orm_mode = True
+
+
+class BucketBase(BaseModel):
+ path: str
+
+
+class BucketCreate(BucketBase):
+ pass
+
+
+class Bucket(BucketBase):
+ id: int
+ bucketfiles: List[BucketFile] = []
+
+ class Config:
+ orm_mode = True
+
diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py
index 090aef0f..a21e32b8 100644
--- a/ui/easydiffusion/utils/save_utils.py
+++ b/ui/easydiffusion/utils/save_utils.py
@@ -142,6 +142,47 @@ def save_images_to_disk(
output_quality=output_format.output_quality,
output_lossless=output_format.output_lossless,
)
+
+ for i in range(len(filtered_images)):
+ path_i = f"{os.path.join(save_dir_path, make_filename(i))}.{output_format.output_format.lower()}"
+
+ def createLoraString(metadata_entries, i):
+ if metadata_entries[i]["use_lora_model"] is None:
+ return "None"
+ elif isinstance(metadata_entries[i]["use_lora_model"], list):
+ loraString = ""
+ for j in range(len(metadata_entries[i]["use_lora_model"])):
+ loraString += metadata_entries[i]["use_lora_model"][j] + ":" + str(metadata_entries[i]["lora_alpha"][j]) + " "
+ return loraString.trim()
+ else:
+ return metadata_entries[i]["use_lora_model"] + ":" + str(metadata_entries[i]["lora_alpha"])
+
+ from easydiffusion.easydb.mappings import Image
+ from easydiffusion.easydb.database import SessionLocal
+
+ session = SessionLocal()
+ session.add(Image(
+ path = path_i,
+ seed = metadata_entries[i]["seed"],
+ use_stable_diffusion_model = metadata_entries[i]["use_stable_diffusion_model"],
+ clip_skip = metadata_entries[i]["clip_skip"],
+ use_vae_model = metadata_entries[i]["use_vae_model"],
+ sampler_name = metadata_entries[i]["sampler_name"],
+ width = metadata_entries[i]["width"],
+ height = metadata_entries[i]["height"],
+ num_inference_steps = metadata_entries[i]["num_inference_steps"],
+ guidance_scale = metadata_entries[i]["guidance_scale"],
+ lora = createLoraString(metadata_entries, i),
+ use_hypernetwork_model = metadata_entries[i]["use_hypernetwork_model"],
+ tiling = metadata_entries[i]["tiling"],
+ use_face_correction = metadata_entries[i]["use_face_correction"],
+ use_upscale = metadata_entries[i]["use_upscale"],
+ prompt = metadata_entries[i]["prompt"],
+ negative_prompt = metadata_entries[i]["negative_prompt"]
+ ))
+ session.commit()
+ session.close()
+
if task_data.metadata_output_format:
for metadata_output_format in task_data.metadata_output_format.split(","):
if metadata_output_format.lower() in ["json", "txt", "embed"]:
diff --git a/ui/index.html b/ui/index.html
index d5ee2ae6..0531933c 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -49,6 +49,9 @@
Help & Community
+
+ Gallery
+
@@ -511,6 +514,10 @@
+