mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-21 02:18:24 +02:00
sqlalchemy bucket v1
This commit is contained in:
parent
4b27b45e4c
commit
646aeac858
@ -25,6 +25,7 @@ modules_to_check = {
|
|||||||
"fastapi": "0.85.1",
|
"fastapi": "0.85.1",
|
||||||
"pycloudflared": "0.2.0",
|
"pycloudflared": "0.2.0",
|
||||||
"ruamel.yaml": "0.17.21",
|
"ruamel.yaml": "0.17.21",
|
||||||
|
"sqlalchemy": "2.0.19",
|
||||||
# "xformers": "0.0.16",
|
# "xformers": "0.0.16",
|
||||||
}
|
}
|
||||||
modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"]
|
modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"]
|
||||||
|
36
ui/easydiffusion/bucket_crud.py
Normal file
36
ui/easydiffusion/bucket_crud.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from easydiffusion import bucket_models, bucket_schemas
|
||||||
|
|
||||||
|
|
||||||
|
def get_bucket(db: Session, bucket_id: int):
|
||||||
|
return db.query(bucket_models.Bucket).filter(bucket_models.Bucket.id == bucket_id).first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_bucket_by_path(db: Session, path: str):
|
||||||
|
return db.query(bucket_models.Bucket).filter(bucket_models.Bucket.path == path).first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_buckets(db: Session, skip: int = 0, limit: int = 100):
|
||||||
|
return db.query(bucket_models.Bucket).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
|
||||||
|
def create_bucket(db: Session, bucket: bucket_schemas.BucketCreate):
|
||||||
|
db_bucket = bucket_models.Bucket(path=bucket.path)
|
||||||
|
db.add(db_bucket)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_bucket)
|
||||||
|
return db_bucket
|
||||||
|
|
||||||
|
|
||||||
|
def get_bucketfiles(db: Session, skip: int = 0, limit: int = 100):
|
||||||
|
return db.query(bucket_models.BucketFile).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
|
||||||
|
def create_bucketfile(db: Session, bucketfile: bucket_schemas.BucketFileCreate, bucket_id: int):
|
||||||
|
db_bucketfile = bucket_models.BucketFile(**bucketfile.dict(), bucket_id=bucket_id)
|
||||||
|
db.add(db_bucketfile)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_bucketfile)
|
||||||
|
return db_bucketfile
|
||||||
|
|
15
ui/easydiffusion/bucket_database.py
Normal file
15
ui/easydiffusion/bucket_database.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import os
|
||||||
|
from easydiffusion import app
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
os.makedirs(app.BUCKET_DIR, exist_ok=True)
|
||||||
|
SQLALCHEMY_DATABASE_URL = "sqlite:///"+os.path.join(app.BUCKET_DIR, "bucket.db")
|
||||||
|
print("## SQLALCHEMY_DATABASE_URL = ", SQLALCHEMY_DATABASE_URL)
|
||||||
|
|
||||||
|
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
|
||||||
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
BucketBase = declarative_base()
|
@ -1,14 +1,57 @@
|
|||||||
import os
|
from typing import List
|
||||||
from easydiffusion import app
|
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
from fastapi import Depends, FastAPI, HTTPException
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
SQLALCHEMY_DATABASE_URL = "sqlite:///"+os.path.join(app.BUCKET_DIR, "bucket.db")
|
from easydiffusion import bucket_crud, bucket_models, bucket_schemas
|
||||||
print("## SQLALCHEMY_DATABASE_URL = ", SQLALCHEMY_DATABASE_URL)
|
from easydiffusion.bucket_database import SessionLocal, engine
|
||||||
|
|
||||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
|
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
||||||
|
|
||||||
Base = declarative_base()
|
def init():
|
||||||
|
from easydiffusion.server import server_api
|
||||||
|
|
||||||
|
bucket_models.BucketBase.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
# Dependency
|
||||||
|
def get_db():
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
@server_api.post("/buckets/", response_model=bucket_schemas.Bucket)
|
||||||
|
def create_bucket(bucket: bucket_schemas.BucketCreate, db: Session = Depends(get_db)):
|
||||||
|
db_bucket = bucket_crud.get_bucket_by_path(db, path=bucket.path)
|
||||||
|
if db_bucket:
|
||||||
|
raise HTTPException(status_code=400, detail="Bucket already exists")
|
||||||
|
return bucket_crud.create_bucket(db=db, bucket=bucket)
|
||||||
|
|
||||||
|
@server_api.get("/buckets/", response_model=List[bucket_schemas.Bucket])
|
||||||
|
def read_bucket(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||||
|
buckets = bucket_crud.get_buckets(db, skip=skip, limit=limit)
|
||||||
|
return buckets
|
||||||
|
|
||||||
|
|
||||||
|
@server_api.get("/buckets/{bucket_id}", response_model=bucket_schemas.Bucket)
|
||||||
|
def read_bucket(bucket_id: int, db: Session = Depends(get_db)):
|
||||||
|
db_bucket = bucket_crud.get_bucket(db, bucket_id=bucket_id)
|
||||||
|
if db_bucket is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Bucket not found")
|
||||||
|
return db_bucket
|
||||||
|
|
||||||
|
|
||||||
|
@server_api.post("/buckets/{bucket_id}/items/", response_model=bucket_schemas.BucketFile)
|
||||||
|
def create_bucketfile_in_bucket(
|
||||||
|
bucket_id: int, bucketfile: bucket_schemas.BucketFileCreate, db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
return bucket_crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id)
|
||||||
|
|
||||||
|
|
||||||
|
@server_api.get("/bucketfiles/", response_model=List[bucket_schemas.BucketFile])
|
||||||
|
def read_bucketfiles(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||||
|
bucketfiles = bucket_crud.get_bucketfiles(db, skip=skip, limit=limit)
|
||||||
|
return bucketfiles
|
||||||
|
|
||||||
|
25
ui/easydiffusion/bucket_models.py
Normal file
25
ui/easydiffusion/bucket_models.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, BLOB
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from easydiffusion.bucket_database import BucketBase
|
||||||
|
|
||||||
|
|
||||||
|
class Bucket(BucketBase):
|
||||||
|
__tablename__ = "bucket"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
path = Column(String, unique=True, index=True)
|
||||||
|
|
||||||
|
bucketfiles = relationship("BucketFile", back_populates="bucket")
|
||||||
|
|
||||||
|
|
||||||
|
class BucketFile(BucketBase):
|
||||||
|
__tablename__ = "bucketfile"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
filename = Column(String, index=True)
|
||||||
|
data = Column(BLOB, index=False)
|
||||||
|
bucket_id = Column(Integer, ForeignKey("bucket.id"))
|
||||||
|
|
||||||
|
bucket = relationship("Bucket", back_populates="bucketfiles")
|
||||||
|
|
37
ui/easydiffusion/bucket_schemas.py
Normal file
37
ui/easydiffusion/bucket_schemas.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BucketFileBase(BaseModel):
|
||||||
|
filename: str
|
||||||
|
data: bytes
|
||||||
|
|
||||||
|
|
||||||
|
class BucketFileCreate(BucketFileBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BucketFile(BucketFileBase):
|
||||||
|
id: int
|
||||||
|
bucket_id: int
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
|
||||||
|
|
||||||
|
class BucketBase(BaseModel):
|
||||||
|
path: str
|
||||||
|
|
||||||
|
|
||||||
|
class BucketCreate(BucketBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Bucket(BucketBase):
|
||||||
|
id: int
|
||||||
|
bucketfiles: List[BucketFile] = []
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
|
@ -8,7 +8,7 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from easydiffusion import app, model_manager, task_manager, bucket_manager
|
from easydiffusion import app, model_manager, task_manager
|
||||||
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
|
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from easydiffusion import model_manager, app, server
|
from easydiffusion import model_manager, app, server, bucket_manager
|
||||||
from easydiffusion.server import server_api # required for uvicorn
|
from easydiffusion.server import server_api # required for uvicorn
|
||||||
|
|
||||||
server.init()
|
server.init()
|
||||||
@ -7,6 +7,7 @@ server.init()
|
|||||||
model_manager.init()
|
model_manager.init()
|
||||||
app.init()
|
app.init()
|
||||||
app.init_render_threads()
|
app.init_render_threads()
|
||||||
|
bucket_manager.init()
|
||||||
|
|
||||||
# start the browser ui
|
# start the browser ui
|
||||||
app.open_browser()
|
app.open_browser()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user