From 97035a54edccf28c5a8fb026c93e5ea4b3da7cf3 Mon Sep 17 00:00:00 2001 From: JeLuF Date: Thu, 27 Jul 2023 21:56:36 +0200 Subject: [PATCH] Bucket API --- scripts/check_modules.py | 1 + ui/easydiffusion/bucket_crud.py | 5 +- ui/easydiffusion/bucket_manager.py | 76 +++++++++++++++++++++++++++++- ui/easydiffusion/bucket_models.py | 6 +-- ui/easydiffusion/bucket_schemas.py | 1 - 5 files changed, 81 insertions(+), 8 deletions(-) diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 0f50e842..c2713c64 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -26,6 +26,7 @@ modules_to_check = { "pycloudflared": "0.2.0", "ruamel.yaml": "0.17.21", "sqlalchemy": "2.0.19", + "python-multipart": "0.0.6", # "xformers": "0.0.16", } modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"] diff --git a/ui/easydiffusion/bucket_crud.py b/ui/easydiffusion/bucket_crud.py index ece14fcc..5dada176 100644 --- a/ui/easydiffusion/bucket_crud.py +++ b/ui/easydiffusion/bucket_crud.py @@ -29,8 +29,9 @@ def get_bucketfiles(db: Session, skip: int = 0, limit: int = 100): def create_bucketfile(db: Session, bucketfile: bucket_schemas.BucketFileCreate, bucket_id: int): db_bucketfile = bucket_models.BucketFile(**bucketfile.dict(), bucket_id=bucket_id) - db.add(db_bucketfile) + db.merge(db_bucketfile) db.commit() - db.refresh(db_bucketfile) + from pprint import pprint + db_bucketfile = db.query(bucket_models.BucketFile).filter(bucket_models.BucketFile.bucket_id==bucket_id, bucket_models.BucketFile.filename==bucketfile.filename).first() return db_bucketfile diff --git a/ui/easydiffusion/bucket_manager.py b/ui/easydiffusion/bucket_manager.py index 0985573f..58d0e30c 100644 --- a/ui/easydiffusion/bucket_manager.py +++ b/ui/easydiffusion/bucket_manager.py @@ -1,11 +1,31 @@ from typing import List -from fastapi import Depends, FastAPI, HTTPException +from fastapi import Depends, FastAPI, HTTPException, Response, File from sqlalchemy.orm import Session from easydiffusion import bucket_crud, bucket_models, bucket_schemas from easydiffusion.bucket_database import SessionLocal, engine +from requests.compat import urlparse + +import base64, json + +MIME_TYPES = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "gif": "image/gif", + "png": "image/png", + "webp": "image/webp", + "js": "text/javascript", + "htm": "text/html", + "html": "text/html", + "css": "text/css", + "json": "application/json", + "mjs": "application/json", + "yaml": "application/yaml", + "svg": "image/svg+xml", + "txt": "text/plain", +} def init(): from easydiffusion.server import server_api @@ -21,6 +41,42 @@ def init(): finally: db.close() + @server_api.get("/bucket/{obj_path:path}") + def bucket_get_object(obj_path: str, db: Session = Depends(get_db)): + filename = get_filename_from_url(obj_path) + path = get_path_from_url(obj_path) + + if filename==None: + bucket = bucket_crud.get_bucket_by_path(db, path=path) + if bucket == None: + raise HTTPException(status_code=404, detail="Bucket not found") + bucketfiles = db.query(bucket_models.BucketFile).with_entities(bucket_models.BucketFile.filename).filter(bucket_models.BucketFile.bucket_id == bucket.id).all() + bucketfiles = [ x.filename for x in bucketfiles ] + return bucketfiles + + else: + bucket_id = bucket_crud.get_bucket_by_path(db, path).id + bucketfile = db.query(bucket_models.BucketFile).filter(bucket_models.BucketFile.bucket_id == bucket_id, bucket_models.BucketFile.filename == filename).first() + + suffix = get_suffix_from_filename(filename) + + return Response(content=bucketfile.data, media_type=MIME_TYPES.get(suffix, "application/octet-stream")) + + @server_api.post("/bucket/{obj_path:path}") + def bucket_post_object(obj_path: str, file: bytes = File(), db: Session = Depends(get_db)): + filename = get_filename_from_url(obj_path) + path = get_path_from_url(obj_path) + bucket = bucket_crud.get_bucket_by_path(db, path) + + if bucket == None: + bucket_id = bucket_crud.create_bucket(db=db, bucket=bucket_schemas.BucketCreate(path=path)) + else: + bucket_id = bucket.id + + bucketfile = bucket_schemas.BucketFileCreate(filename=filename, data=file) + result = bucket_crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id) + result.data = base64.encodestring(result.data) + return result @server_api.post("/buckets/", response_model=bucket_schemas.Bucket) def create_bucket(bucket: bucket_schemas.BucketCreate, db: Session = Depends(get_db)): @@ -47,7 +103,10 @@ def init(): def create_bucketfile_in_bucket( bucket_id: int, bucketfile: bucket_schemas.BucketFileCreate, db: Session = Depends(get_db) ): - return bucket_crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id) + bucketfile.data = base64.decodestring(bucketfile.data) + result = bucket_crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id) + result.data = base64.encodestring(result.data) + return result @server_api.get("/bucketfiles/", response_model=List[bucket_schemas.BucketFile]) @@ -55,3 +114,16 @@ def init(): bucketfiles = bucket_crud.get_bucketfiles(db, skip=skip, limit=limit) return bucketfiles + +def get_filename_from_url(url): + path = urlparse(url).path + name = path[path.rfind('/')+1:] + return name or None + +def get_path_from_url(url): + path = urlparse(url).path + path = path[0:path.rfind('/')] + return path or None + +def get_suffix_from_filename(filename): + return filename[filename.rfind('.')+1:] diff --git a/ui/easydiffusion/bucket_models.py b/ui/easydiffusion/bucket_models.py index 3ede1a36..933f8140 100644 --- a/ui/easydiffusion/bucket_models.py +++ b/ui/easydiffusion/bucket_models.py @@ -16,10 +16,10 @@ class Bucket(BucketBase): class BucketFile(BucketBase): __tablename__ = "bucketfile" - id = Column(Integer, primary_key=True, index=True) - filename = Column(String, index=True) + filename = Column(String, index=True, primary_key=True) + bucket_id = Column(Integer, ForeignKey("bucket.id"), primary_key=True) + data = Column(BLOB, index=False) - bucket_id = Column(Integer, ForeignKey("bucket.id")) bucket = relationship("Bucket", back_populates="bucketfiles") diff --git a/ui/easydiffusion/bucket_schemas.py b/ui/easydiffusion/bucket_schemas.py index 9f7e2377..68bc04e2 100644 --- a/ui/easydiffusion/bucket_schemas.py +++ b/ui/easydiffusion/bucket_schemas.py @@ -13,7 +13,6 @@ class BucketFileCreate(BucketFileBase): class BucketFile(BucketFileBase): - id: int bucket_id: int class Config: