diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 03e50db0..0f50e842 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -25,6 +25,7 @@ modules_to_check = { "fastapi": "0.85.1", "pycloudflared": "0.2.0", "ruamel.yaml": "0.17.21", + "sqlalchemy": "2.0.19", # "xformers": "0.0.16", } modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"] diff --git a/ui/easydiffusion/bucket_crud.py b/ui/easydiffusion/bucket_crud.py new file mode 100644 index 00000000..ece14fcc --- /dev/null +++ b/ui/easydiffusion/bucket_crud.py @@ -0,0 +1,36 @@ +from sqlalchemy.orm import Session + +from easydiffusion import bucket_models, bucket_schemas + + +def get_bucket(db: Session, bucket_id: int): + return db.query(bucket_models.Bucket).filter(bucket_models.Bucket.id == bucket_id).first() + + +def get_bucket_by_path(db: Session, path: str): + return db.query(bucket_models.Bucket).filter(bucket_models.Bucket.path == path).first() + + +def get_buckets(db: Session, skip: int = 0, limit: int = 100): + return db.query(bucket_models.Bucket).offset(skip).limit(limit).all() + + +def create_bucket(db: Session, bucket: bucket_schemas.BucketCreate): + db_bucket = bucket_models.Bucket(path=bucket.path) + db.add(db_bucket) + db.commit() + db.refresh(db_bucket) + return db_bucket + + +def get_bucketfiles(db: Session, skip: int = 0, limit: int = 100): + return db.query(bucket_models.BucketFile).offset(skip).limit(limit).all() + + +def create_bucketfile(db: Session, bucketfile: bucket_schemas.BucketFileCreate, bucket_id: int): + db_bucketfile = bucket_models.BucketFile(**bucketfile.dict(), bucket_id=bucket_id) + db.add(db_bucketfile) + db.commit() + db.refresh(db_bucketfile) + return db_bucketfile + diff --git a/ui/easydiffusion/bucket_database.py b/ui/easydiffusion/bucket_database.py new file mode 100644 index 00000000..e3c92845 --- /dev/null +++ b/ui/easydiffusion/bucket_database.py @@ -0,0 +1,15 @@ +import os +from easydiffusion import app + +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +os.makedirs(app.BUCKET_DIR, exist_ok=True) +SQLALCHEMY_DATABASE_URL = "sqlite:///"+os.path.join(app.BUCKET_DIR, "bucket.db") +print("## SQLALCHEMY_DATABASE_URL = ", SQLALCHEMY_DATABASE_URL) + +engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +BucketBase = declarative_base() diff --git a/ui/easydiffusion/bucket_manager.py b/ui/easydiffusion/bucket_manager.py index e24b0cc0..0985573f 100644 --- a/ui/easydiffusion/bucket_manager.py +++ b/ui/easydiffusion/bucket_manager.py @@ -1,14 +1,57 @@ -import os -from easydiffusion import app +from typing import List -from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from fastapi import Depends, FastAPI, HTTPException +from sqlalchemy.orm import Session -SQLALCHEMY_DATABASE_URL = "sqlite:///"+os.path.join(app.BUCKET_DIR, "bucket.db") -print("## SQLALCHEMY_DATABASE_URL = ", SQLALCHEMY_DATABASE_URL) +from easydiffusion import bucket_crud, bucket_models, bucket_schemas +from easydiffusion.bucket_database import SessionLocal, engine -engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base = declarative_base() +def init(): + from easydiffusion.server import server_api + + bucket_models.BucketBase.metadata.create_all(bind=engine) + + + # Dependency + def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + + @server_api.post("/buckets/", response_model=bucket_schemas.Bucket) + def create_bucket(bucket: bucket_schemas.BucketCreate, db: Session = Depends(get_db)): + db_bucket = bucket_crud.get_bucket_by_path(db, path=bucket.path) + if db_bucket: + raise HTTPException(status_code=400, detail="Bucket already exists") + return bucket_crud.create_bucket(db=db, bucket=bucket) + + @server_api.get("/buckets/", response_model=List[bucket_schemas.Bucket]) + def read_bucket(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): + buckets = bucket_crud.get_buckets(db, skip=skip, limit=limit) + return buckets + + + @server_api.get("/buckets/{bucket_id}", response_model=bucket_schemas.Bucket) + def read_bucket(bucket_id: int, db: Session = Depends(get_db)): + db_bucket = bucket_crud.get_bucket(db, bucket_id=bucket_id) + if db_bucket is None: + raise HTTPException(status_code=404, detail="Bucket not found") + return db_bucket + + + @server_api.post("/buckets/{bucket_id}/items/", response_model=bucket_schemas.BucketFile) + def create_bucketfile_in_bucket( + bucket_id: int, bucketfile: bucket_schemas.BucketFileCreate, db: Session = Depends(get_db) + ): + return bucket_crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id) + + + @server_api.get("/bucketfiles/", response_model=List[bucket_schemas.BucketFile]) + def read_bucketfiles(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): + bucketfiles = bucket_crud.get_bucketfiles(db, skip=skip, limit=limit) + return bucketfiles + diff --git a/ui/easydiffusion/bucket_models.py b/ui/easydiffusion/bucket_models.py new file mode 100644 index 00000000..3ede1a36 --- /dev/null +++ b/ui/easydiffusion/bucket_models.py @@ -0,0 +1,25 @@ +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, BLOB +from sqlalchemy.orm import relationship + +from easydiffusion.bucket_database import BucketBase + + +class Bucket(BucketBase): + __tablename__ = "bucket" + + id = Column(Integer, primary_key=True, index=True) + path = Column(String, unique=True, index=True) + + bucketfiles = relationship("BucketFile", back_populates="bucket") + + +class BucketFile(BucketBase): + __tablename__ = "bucketfile" + + id = Column(Integer, primary_key=True, index=True) + filename = Column(String, index=True) + data = Column(BLOB, index=False) + bucket_id = Column(Integer, ForeignKey("bucket.id")) + + bucket = relationship("Bucket", back_populates="bucketfiles") + diff --git a/ui/easydiffusion/bucket_schemas.py b/ui/easydiffusion/bucket_schemas.py new file mode 100644 index 00000000..9f7e2377 --- /dev/null +++ b/ui/easydiffusion/bucket_schemas.py @@ -0,0 +1,37 @@ +from typing import List, Union + +from pydantic import BaseModel + + +class BucketFileBase(BaseModel): + filename: str + data: bytes + + +class BucketFileCreate(BucketFileBase): + pass + + +class BucketFile(BucketFileBase): + id: int + bucket_id: int + + class Config: + orm_mode = True + + +class BucketBase(BaseModel): + path: str + + +class BucketCreate(BucketBase): + pass + + +class Bucket(BucketBase): + id: int + bucketfiles: List[BucketFile] = [] + + class Config: + orm_mode = True + diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index 4705d112..df788b0c 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -8,7 +8,7 @@ import os import traceback from typing import List, Union -from easydiffusion import app, model_manager, task_manager, bucket_manager +from easydiffusion import app, model_manager, task_manager from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData from easydiffusion.utils import log from fastapi import FastAPI, HTTPException diff --git a/ui/main.py b/ui/main.py index f5998622..a7568ba6 100644 --- a/ui/main.py +++ b/ui/main.py @@ -1,4 +1,4 @@ -from easydiffusion import model_manager, app, server +from easydiffusion import model_manager, app, server, bucket_manager from easydiffusion.server import server_api # required for uvicorn server.init() @@ -7,6 +7,7 @@ server.init() model_manager.init() app.init() app.init_render_threads() +bucket_manager.init() # start the browser ui app.open_browser()