diff --git a/scripts/check_modules.py b/scripts/check_modules.py index a5edf5f5..aecf7576 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -25,6 +25,8 @@ modules_to_check = { "fastapi": "0.85.1", "pycloudflared": "0.2.0", "ruamel.yaml": "0.17.21", + "sqlalchemy": "2.0.19", + "python-multipart": "0.0.6", # "xformers": "0.0.16", } modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"] diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index aa1c5ba7..e2c190f8 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -38,6 +38,7 @@ SD_UI_DIR = os.getenv("SD_UI_PATH", None) CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts")) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models")) +BUCKET_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "bucket")) USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins")) CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins")) diff --git a/ui/easydiffusion/bucket_manager.py b/ui/easydiffusion/bucket_manager.py new file mode 100644 index 00000000..4400fd17 --- /dev/null +++ b/ui/easydiffusion/bucket_manager.py @@ -0,0 +1,127 @@ +from typing import List + +from fastapi import Depends, FastAPI, HTTPException, Response, File +from fastapi.responses import FileResponse +from sqlalchemy.orm import Session + +from easydiffusion.easydb import crud, models, schemas +from easydiffusion.easydb.database import SessionLocal, engine + +from requests.compat import urlparse +from os.path import abspath + +import base64, json + +MIME_TYPES = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "gif": "image/gif", + "png": "image/png", + "webp": "image/webp", + "js": "text/javascript", + "htm": "text/html", + "html": "text/html", + "css": "text/css", + "json": "application/json", + "mjs": "application/json", + "yaml": "application/yaml", + "svg": "image/svg+xml", + "txt": "text/plain", +} + +def init(): + from easydiffusion.server import server_api + + models.BucketBase.metadata.create_all(bind=engine) + + + # Dependency + def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + @server_api.get("/bucket/{obj_path:path}") + def bucket_get_object(obj_path: str, db: Session = Depends(get_db)): + filename = get_filename_from_url(obj_path) + path = get_path_from_url(obj_path) + + if filename==None: + bucket = crud.get_bucket_by_path(db, path=path) + if bucket == None: + raise HTTPException(status_code=404, detail="Bucket not found") + bucketfiles = db.query(models.BucketFile).with_entities(models.BucketFile.filename).filter(models.BucketFile.bucket_id == bucket.id).all() + bucketfiles = [ x.filename for x in bucketfiles ] + return bucketfiles + + else: + bucket_id = crud.get_bucket_by_path(db, path).id + bucketfile = db.query(models.BucketFile).filter(models.BucketFile.bucket_id == bucket_id, models.BucketFile.filename == filename).first() + + suffix = get_suffix_from_filename(filename) + + return Response(content=bucketfile.data, media_type=MIME_TYPES.get(suffix, "application/octet-stream")) + + @server_api.post("/bucket/{obj_path:path}") + def bucket_post_object(obj_path: str, file: bytes = File(), db: Session = Depends(get_db)): + filename = get_filename_from_url(obj_path) + path = get_path_from_url(obj_path) + bucket = crud.get_bucket_by_path(db, path) + + if bucket == None: + bucket_id = crud.create_bucket(db=db, bucket=schemas.BucketCreate(path=path)) + else: + bucket_id = bucket.id + + bucketfile = schemas.BucketFileCreate(filename=filename, data=file) + result = crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id) + result.data = base64.encodestring(result.data) + return result + + + @server_api.post("/buckets/{bucket_id}/items/", response_model=schemas.BucketFile) + def create_bucketfile_in_bucket( + bucket_id: int, bucketfile: schemas.BucketFileCreate, db: Session = Depends(get_db) + ): + bucketfile.data = base64.decodestring(bucketfile.data) + result = crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id) + result.data = base64.encodestring(result.data) + return result + + @server_api.get("/image/{image_path:path}") + def get_image(image_path: str, db: Session = Depends(get_db)): + from easydiffusion.easydb.mappings import Image + image_path = str(abspath(image_path)) + amount = len(db.query(Image).filter(Image.path == image_path).all()) + if amount > 0: + image = db.query(Image).filter(Image.path == image_path).first() + return FileResponse(image.path) + else: + raise HTTPException(status_code=404, detail="Image not found") + + @server_api.get("/all_images") + def get_all_images(db: Session = Depends(get_db)): + from easydiffusion.easydb.mappings import Image + images = db.query(Image).all() + sum_string = "
" + for img in images: + options = f"Path: {img.path}\nPrompt: {img.prompt}\nNegative Prompt: {img.negative_prompt}\nSeed: {img.seed}\nModel: {img.use_stable_diffusion_model}\nSize: {img.height}x{img.width}\nSampler: {img.sampler_name}\nSteps: {img.num_inference_steps}\nGuidance Scale: {img.guidance_scale}\nLoRA: {img.lora}\nUpscaling: {img.use_upscale}\nFace Correction: {img.use_face_correction}\n" + sum_string += f"" + sum_string += "
" + return Response(content=sum_string, media_type="text/html") + + +def get_filename_from_url(url): + path = urlparse(url).path + name = path[path.rfind('/')+1:] + return name or None + +def get_path_from_url(url): + path = urlparse(url).path + path = path[0:path.rfind('/')] + return path or None + +def get_suffix_from_filename(filename): + return filename[filename.rfind('.')+1:] diff --git a/ui/easydiffusion/easydb/crud.py b/ui/easydiffusion/easydb/crud.py new file mode 100644 index 00000000..7550a52a --- /dev/null +++ b/ui/easydiffusion/easydb/crud.py @@ -0,0 +1,25 @@ +from sqlalchemy.orm import Session + +from easydiffusion.easydb import models, schemas + + +def get_bucket_by_path(db: Session, path: str): + return db.query(models.Bucket).filter(models.Bucket.path == path).first() + + +def create_bucket(db: Session, bucket: schemas.BucketCreate): + db_bucket = models.Bucket(path=bucket.path) + db.add(db_bucket) + db.commit() + db.refresh(db_bucket) + return db_bucket + + +def create_bucketfile(db: Session, bucketfile: schemas.BucketFileCreate, bucket_id: int): + db_bucketfile = models.BucketFile(**bucketfile.dict(), bucket_id=bucket_id) + db.merge(db_bucketfile) + db.commit() + from pprint import pprint + db_bucketfile = db.query(models.BucketFile).filter(models.BucketFile.bucket_id==bucket_id, models.BucketFile.filename==bucketfile.filename).first() + return db_bucketfile + diff --git a/ui/easydiffusion/easydb/database.py b/ui/easydiffusion/easydb/database.py new file mode 100644 index 00000000..e3c92845 --- /dev/null +++ b/ui/easydiffusion/easydb/database.py @@ -0,0 +1,15 @@ +import os +from easydiffusion import app + +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +os.makedirs(app.BUCKET_DIR, exist_ok=True) +SQLALCHEMY_DATABASE_URL = "sqlite:///"+os.path.join(app.BUCKET_DIR, "bucket.db") +print("## SQLALCHEMY_DATABASE_URL = ", SQLALCHEMY_DATABASE_URL) + +engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +BucketBase = declarative_base() diff --git a/ui/easydiffusion/easydb/mappings.py b/ui/easydiffusion/easydb/mappings.py new file mode 100644 index 00000000..ad68ecab --- /dev/null +++ b/ui/easydiffusion/easydb/mappings.py @@ -0,0 +1,32 @@ +from sqlalchemy import Column, Integer, String, Float, Boolean +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class Image(Base): + __tablename__ = 'images' + + path = Column(String, primary_key=True) + seed = Column(Integer) + use_stable_diffusion_model = Column(String) + clip_skip = Column(Boolean) + use_vae_model = Column(String) + sampler_name = Column(String) + width = Column(Integer) + height = Column(Integer) + num_inference_steps = Column(Integer) + guidance_scale = Column(Float) + lora = Column(String) + use_hypernetwork_model = Column(String) + tiling = Column(String) + use_face_correction = Column(String) + use_upscale = Column(String) + prompt = Column(String) + negative_prompt = Column(String) + + def __repr__(self): + return "" % ( + self.path, self.seed, self.use_stable_diffusion_model, self.clip_skip, self.use_vae_model, self.sampler_name, self.width, self.height, self.num_inference_steps, self.guidance_scale, self.lora, self.use_hypernetwork_model, self.tiling, self.use_face_correction, self.use_upscale, self.prompt, self.negative_prompt) + +from easydiffusion.easydb.database import engine +Image.metadata.create_all(engine) \ No newline at end of file diff --git a/ui/easydiffusion/easydb/models.py b/ui/easydiffusion/easydb/models.py new file mode 100644 index 00000000..04834951 --- /dev/null +++ b/ui/easydiffusion/easydb/models.py @@ -0,0 +1,25 @@ +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, BLOB +from sqlalchemy.orm import relationship + +from easydiffusion.easydb.database import BucketBase + + +class Bucket(BucketBase): + __tablename__ = "bucket" + + id = Column(Integer, primary_key=True, index=True) + path = Column(String, unique=True, index=True) + + bucketfiles = relationship("BucketFile", back_populates="bucket") + + +class BucketFile(BucketBase): + __tablename__ = "bucketfile" + + filename = Column(String, index=True, primary_key=True) + bucket_id = Column(Integer, ForeignKey("bucket.id"), primary_key=True) + + data = Column(BLOB, index=False) + + bucket = relationship("Bucket", back_populates="bucketfiles") + diff --git a/ui/easydiffusion/easydb/schemas.py b/ui/easydiffusion/easydb/schemas.py new file mode 100644 index 00000000..68bc04e2 --- /dev/null +++ b/ui/easydiffusion/easydb/schemas.py @@ -0,0 +1,36 @@ +from typing import List, Union + +from pydantic import BaseModel + + +class BucketFileBase(BaseModel): + filename: str + data: bytes + + +class BucketFileCreate(BucketFileBase): + pass + + +class BucketFile(BucketFileBase): + bucket_id: int + + class Config: + orm_mode = True + + +class BucketBase(BaseModel): + path: str + + +class BucketCreate(BucketBase): + pass + + +class Bucket(BucketBase): + id: int + bucketfiles: List[BucketFile] = [] + + class Config: + orm_mode = True + diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index 090aef0f..a21e32b8 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -142,6 +142,47 @@ def save_images_to_disk( output_quality=output_format.output_quality, output_lossless=output_format.output_lossless, ) + + for i in range(len(filtered_images)): + path_i = f"{os.path.join(save_dir_path, make_filename(i))}.{output_format.output_format.lower()}" + + def createLoraString(metadata_entries, i): + if metadata_entries[i]["use_lora_model"] is None: + return "None" + elif isinstance(metadata_entries[i]["use_lora_model"], list): + loraString = "" + for j in range(len(metadata_entries[i]["use_lora_model"])): + loraString += metadata_entries[i]["use_lora_model"][j] + ":" + str(metadata_entries[i]["lora_alpha"][j]) + " " + return loraString.trim() + else: + return metadata_entries[i]["use_lora_model"] + ":" + str(metadata_entries[i]["lora_alpha"]) + + from easydiffusion.easydb.mappings import Image + from easydiffusion.easydb.database import SessionLocal + + session = SessionLocal() + session.add(Image( + path = path_i, + seed = metadata_entries[i]["seed"], + use_stable_diffusion_model = metadata_entries[i]["use_stable_diffusion_model"], + clip_skip = metadata_entries[i]["clip_skip"], + use_vae_model = metadata_entries[i]["use_vae_model"], + sampler_name = metadata_entries[i]["sampler_name"], + width = metadata_entries[i]["width"], + height = metadata_entries[i]["height"], + num_inference_steps = metadata_entries[i]["num_inference_steps"], + guidance_scale = metadata_entries[i]["guidance_scale"], + lora = createLoraString(metadata_entries, i), + use_hypernetwork_model = metadata_entries[i]["use_hypernetwork_model"], + tiling = metadata_entries[i]["tiling"], + use_face_correction = metadata_entries[i]["use_face_correction"], + use_upscale = metadata_entries[i]["use_upscale"], + prompt = metadata_entries[i]["prompt"], + negative_prompt = metadata_entries[i]["negative_prompt"] + )) + session.commit() + session.close() + if task_data.metadata_output_format: for metadata_output_format in task_data.metadata_output_format.split(","): if metadata_output_format.lower() in ["json", "txt", "embed"]: diff --git a/ui/index.html b/ui/index.html index d5ee2ae6..0531933c 100644 --- a/ui/index.html +++ b/ui/index.html @@ -49,6 +49,9 @@ Help & Community + + Gallery + @@ -511,6 +514,10 @@ +