sqlalchemy bucket v1

This commit is contained in:
JeLuF 2023-07-26 01:16:27 +02:00
parent 4b27b45e4c
commit 646aeac858
8 changed files with 170 additions and 12 deletions

View File

@ -25,6 +25,7 @@ modules_to_check = {
"fastapi": "0.85.1", "fastapi": "0.85.1",
"pycloudflared": "0.2.0", "pycloudflared": "0.2.0",
"ruamel.yaml": "0.17.21", "ruamel.yaml": "0.17.21",
"sqlalchemy": "2.0.19",
# "xformers": "0.0.16", # "xformers": "0.0.16",
} }
modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"] modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"]

View File

@ -0,0 +1,36 @@
from sqlalchemy.orm import Session
from easydiffusion import bucket_models, bucket_schemas
def get_bucket(db: Session, bucket_id: int):
return db.query(bucket_models.Bucket).filter(bucket_models.Bucket.id == bucket_id).first()
def get_bucket_by_path(db: Session, path: str):
return db.query(bucket_models.Bucket).filter(bucket_models.Bucket.path == path).first()
def get_buckets(db: Session, skip: int = 0, limit: int = 100):
return db.query(bucket_models.Bucket).offset(skip).limit(limit).all()
def create_bucket(db: Session, bucket: bucket_schemas.BucketCreate):
db_bucket = bucket_models.Bucket(path=bucket.path)
db.add(db_bucket)
db.commit()
db.refresh(db_bucket)
return db_bucket
def get_bucketfiles(db: Session, skip: int = 0, limit: int = 100):
return db.query(bucket_models.BucketFile).offset(skip).limit(limit).all()
def create_bucketfile(db: Session, bucketfile: bucket_schemas.BucketFileCreate, bucket_id: int):
db_bucketfile = bucket_models.BucketFile(**bucketfile.dict(), bucket_id=bucket_id)
db.add(db_bucketfile)
db.commit()
db.refresh(db_bucketfile)
return db_bucketfile

View File

@ -0,0 +1,15 @@
import os
from easydiffusion import app
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
os.makedirs(app.BUCKET_DIR, exist_ok=True)
SQLALCHEMY_DATABASE_URL = "sqlite:///"+os.path.join(app.BUCKET_DIR, "bucket.db")
print("## SQLALCHEMY_DATABASE_URL = ", SQLALCHEMY_DATABASE_URL)
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
BucketBase = declarative_base()

View File

@ -1,14 +1,57 @@
import os from typing import List
from easydiffusion import app
from sqlalchemy import create_engine from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///"+os.path.join(app.BUCKET_DIR, "bucket.db") from easydiffusion import bucket_crud, bucket_models, bucket_schemas
print("## SQLALCHEMY_DATABASE_URL = ", SQLALCHEMY_DATABASE_URL) from easydiffusion.bucket_database import SessionLocal, engine
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() def init():
from easydiffusion.server import server_api
bucket_models.BucketBase.metadata.create_all(bind=engine)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@server_api.post("/buckets/", response_model=bucket_schemas.Bucket)
def create_bucket(bucket: bucket_schemas.BucketCreate, db: Session = Depends(get_db)):
db_bucket = bucket_crud.get_bucket_by_path(db, path=bucket.path)
if db_bucket:
raise HTTPException(status_code=400, detail="Bucket already exists")
return bucket_crud.create_bucket(db=db, bucket=bucket)
@server_api.get("/buckets/", response_model=List[bucket_schemas.Bucket])
def read_bucket(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
buckets = bucket_crud.get_buckets(db, skip=skip, limit=limit)
return buckets
@server_api.get("/buckets/{bucket_id}", response_model=bucket_schemas.Bucket)
def read_bucket(bucket_id: int, db: Session = Depends(get_db)):
db_bucket = bucket_crud.get_bucket(db, bucket_id=bucket_id)
if db_bucket is None:
raise HTTPException(status_code=404, detail="Bucket not found")
return db_bucket
@server_api.post("/buckets/{bucket_id}/items/", response_model=bucket_schemas.BucketFile)
def create_bucketfile_in_bucket(
bucket_id: int, bucketfile: bucket_schemas.BucketFileCreate, db: Session = Depends(get_db)
):
return bucket_crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id)
@server_api.get("/bucketfiles/", response_model=List[bucket_schemas.BucketFile])
def read_bucketfiles(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
bucketfiles = bucket_crud.get_bucketfiles(db, skip=skip, limit=limit)
return bucketfiles

View File

@ -0,0 +1,25 @@
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, BLOB
from sqlalchemy.orm import relationship
from easydiffusion.bucket_database import BucketBase
class Bucket(BucketBase):
__tablename__ = "bucket"
id = Column(Integer, primary_key=True, index=True)
path = Column(String, unique=True, index=True)
bucketfiles = relationship("BucketFile", back_populates="bucket")
class BucketFile(BucketBase):
__tablename__ = "bucketfile"
id = Column(Integer, primary_key=True, index=True)
filename = Column(String, index=True)
data = Column(BLOB, index=False)
bucket_id = Column(Integer, ForeignKey("bucket.id"))
bucket = relationship("Bucket", back_populates="bucketfiles")

View File

@ -0,0 +1,37 @@
from typing import List, Union
from pydantic import BaseModel
class BucketFileBase(BaseModel):
filename: str
data: bytes
class BucketFileCreate(BucketFileBase):
pass
class BucketFile(BucketFileBase):
id: int
bucket_id: int
class Config:
orm_mode = True
class BucketBase(BaseModel):
path: str
class BucketCreate(BucketBase):
pass
class Bucket(BucketBase):
id: int
bucketfiles: List[BucketFile] = []
class Config:
orm_mode = True

View File

@ -8,7 +8,7 @@ import os
import traceback import traceback
from typing import List, Union from typing import List, Union
from easydiffusion import app, model_manager, task_manager, bucket_manager from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
from easydiffusion.utils import log from easydiffusion.utils import log
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException

View File

@ -1,4 +1,4 @@
from easydiffusion import model_manager, app, server from easydiffusion import model_manager, app, server, bucket_manager
from easydiffusion.server import server_api # required for uvicorn from easydiffusion.server import server_api # required for uvicorn
server.init() server.init()
@ -7,6 +7,7 @@ server.init()
model_manager.init() model_manager.init()
app.init() app.init()
app.init_render_threads() app.init_render_threads()
bucket_manager.init()
# start the browser ui # start the browser ui
app.open_browser() app.open_browser()