mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-21 18:31:28 +02:00
Bucket API
This commit is contained in:
parent
646aeac858
commit
97035a54ed
@ -26,6 +26,7 @@ modules_to_check = {
|
|||||||
"pycloudflared": "0.2.0",
|
"pycloudflared": "0.2.0",
|
||||||
"ruamel.yaml": "0.17.21",
|
"ruamel.yaml": "0.17.21",
|
||||||
"sqlalchemy": "2.0.19",
|
"sqlalchemy": "2.0.19",
|
||||||
|
"python-multipart": "0.0.6",
|
||||||
# "xformers": "0.0.16",
|
# "xformers": "0.0.16",
|
||||||
}
|
}
|
||||||
modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"]
|
modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit"]
|
||||||
|
@ -29,8 +29,9 @@ def get_bucketfiles(db: Session, skip: int = 0, limit: int = 100):
|
|||||||
|
|
||||||
def create_bucketfile(db: Session, bucketfile: bucket_schemas.BucketFileCreate, bucket_id: int):
|
def create_bucketfile(db: Session, bucketfile: bucket_schemas.BucketFileCreate, bucket_id: int):
|
||||||
db_bucketfile = bucket_models.BucketFile(**bucketfile.dict(), bucket_id=bucket_id)
|
db_bucketfile = bucket_models.BucketFile(**bucketfile.dict(), bucket_id=bucket_id)
|
||||||
db.add(db_bucketfile)
|
db.merge(db_bucketfile)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_bucketfile)
|
from pprint import pprint
|
||||||
|
db_bucketfile = db.query(bucket_models.BucketFile).filter(bucket_models.BucketFile.bucket_id==bucket_id, bucket_models.BucketFile.filename==bucketfile.filename).first()
|
||||||
return db_bucketfile
|
return db_bucketfile
|
||||||
|
|
||||||
|
@ -1,11 +1,31 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException
|
from fastapi import Depends, FastAPI, HTTPException, Response, File
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from easydiffusion import bucket_crud, bucket_models, bucket_schemas
|
from easydiffusion import bucket_crud, bucket_models, bucket_schemas
|
||||||
from easydiffusion.bucket_database import SessionLocal, engine
|
from easydiffusion.bucket_database import SessionLocal, engine
|
||||||
|
|
||||||
|
from requests.compat import urlparse
|
||||||
|
|
||||||
|
import base64, json
|
||||||
|
|
||||||
|
MIME_TYPES = {
|
||||||
|
"jpg": "image/jpeg",
|
||||||
|
"jpeg": "image/jpeg",
|
||||||
|
"gif": "image/gif",
|
||||||
|
"png": "image/png",
|
||||||
|
"webp": "image/webp",
|
||||||
|
"js": "text/javascript",
|
||||||
|
"htm": "text/html",
|
||||||
|
"html": "text/html",
|
||||||
|
"css": "text/css",
|
||||||
|
"json": "application/json",
|
||||||
|
"mjs": "application/json",
|
||||||
|
"yaml": "application/yaml",
|
||||||
|
"svg": "image/svg+xml",
|
||||||
|
"txt": "text/plain",
|
||||||
|
}
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
from easydiffusion.server import server_api
|
from easydiffusion.server import server_api
|
||||||
@ -21,6 +41,42 @@ def init():
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
@server_api.get("/bucket/{obj_path:path}")
|
||||||
|
def bucket_get_object(obj_path: str, db: Session = Depends(get_db)):
|
||||||
|
filename = get_filename_from_url(obj_path)
|
||||||
|
path = get_path_from_url(obj_path)
|
||||||
|
|
||||||
|
if filename==None:
|
||||||
|
bucket = bucket_crud.get_bucket_by_path(db, path=path)
|
||||||
|
if bucket == None:
|
||||||
|
raise HTTPException(status_code=404, detail="Bucket not found")
|
||||||
|
bucketfiles = db.query(bucket_models.BucketFile).with_entities(bucket_models.BucketFile.filename).filter(bucket_models.BucketFile.bucket_id == bucket.id).all()
|
||||||
|
bucketfiles = [ x.filename for x in bucketfiles ]
|
||||||
|
return bucketfiles
|
||||||
|
|
||||||
|
else:
|
||||||
|
bucket_id = bucket_crud.get_bucket_by_path(db, path).id
|
||||||
|
bucketfile = db.query(bucket_models.BucketFile).filter(bucket_models.BucketFile.bucket_id == bucket_id, bucket_models.BucketFile.filename == filename).first()
|
||||||
|
|
||||||
|
suffix = get_suffix_from_filename(filename)
|
||||||
|
|
||||||
|
return Response(content=bucketfile.data, media_type=MIME_TYPES.get(suffix, "application/octet-stream"))
|
||||||
|
|
||||||
|
@server_api.post("/bucket/{obj_path:path}")
|
||||||
|
def bucket_post_object(obj_path: str, file: bytes = File(), db: Session = Depends(get_db)):
|
||||||
|
filename = get_filename_from_url(obj_path)
|
||||||
|
path = get_path_from_url(obj_path)
|
||||||
|
bucket = bucket_crud.get_bucket_by_path(db, path)
|
||||||
|
|
||||||
|
if bucket == None:
|
||||||
|
bucket_id = bucket_crud.create_bucket(db=db, bucket=bucket_schemas.BucketCreate(path=path))
|
||||||
|
else:
|
||||||
|
bucket_id = bucket.id
|
||||||
|
|
||||||
|
bucketfile = bucket_schemas.BucketFileCreate(filename=filename, data=file)
|
||||||
|
result = bucket_crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id)
|
||||||
|
result.data = base64.encodestring(result.data)
|
||||||
|
return result
|
||||||
|
|
||||||
@server_api.post("/buckets/", response_model=bucket_schemas.Bucket)
|
@server_api.post("/buckets/", response_model=bucket_schemas.Bucket)
|
||||||
def create_bucket(bucket: bucket_schemas.BucketCreate, db: Session = Depends(get_db)):
|
def create_bucket(bucket: bucket_schemas.BucketCreate, db: Session = Depends(get_db)):
|
||||||
@ -47,7 +103,10 @@ def init():
|
|||||||
def create_bucketfile_in_bucket(
|
def create_bucketfile_in_bucket(
|
||||||
bucket_id: int, bucketfile: bucket_schemas.BucketFileCreate, db: Session = Depends(get_db)
|
bucket_id: int, bucketfile: bucket_schemas.BucketFileCreate, db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
return bucket_crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id)
|
bucketfile.data = base64.decodestring(bucketfile.data)
|
||||||
|
result = bucket_crud.create_bucketfile(db=db, bucketfile=bucketfile, bucket_id=bucket_id)
|
||||||
|
result.data = base64.encodestring(result.data)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@server_api.get("/bucketfiles/", response_model=List[bucket_schemas.BucketFile])
|
@server_api.get("/bucketfiles/", response_model=List[bucket_schemas.BucketFile])
|
||||||
@ -55,3 +114,16 @@ def init():
|
|||||||
bucketfiles = bucket_crud.get_bucketfiles(db, skip=skip, limit=limit)
|
bucketfiles = bucket_crud.get_bucketfiles(db, skip=skip, limit=limit)
|
||||||
return bucketfiles
|
return bucketfiles
|
||||||
|
|
||||||
|
|
||||||
|
def get_filename_from_url(url):
|
||||||
|
path = urlparse(url).path
|
||||||
|
name = path[path.rfind('/')+1:]
|
||||||
|
return name or None
|
||||||
|
|
||||||
|
def get_path_from_url(url):
|
||||||
|
path = urlparse(url).path
|
||||||
|
path = path[0:path.rfind('/')]
|
||||||
|
return path or None
|
||||||
|
|
||||||
|
def get_suffix_from_filename(filename):
|
||||||
|
return filename[filename.rfind('.')+1:]
|
||||||
|
@ -16,10 +16,10 @@ class Bucket(BucketBase):
|
|||||||
class BucketFile(BucketBase):
|
class BucketFile(BucketBase):
|
||||||
__tablename__ = "bucketfile"
|
__tablename__ = "bucketfile"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
filename = Column(String, index=True, primary_key=True)
|
||||||
filename = Column(String, index=True)
|
bucket_id = Column(Integer, ForeignKey("bucket.id"), primary_key=True)
|
||||||
|
|
||||||
data = Column(BLOB, index=False)
|
data = Column(BLOB, index=False)
|
||||||
bucket_id = Column(Integer, ForeignKey("bucket.id"))
|
|
||||||
|
|
||||||
bucket = relationship("Bucket", back_populates="bucketfiles")
|
bucket = relationship("Bucket", back_populates="bucketfiles")
|
||||||
|
|
||||||
|
@ -13,7 +13,6 @@ class BucketFileCreate(BucketFileBase):
|
|||||||
|
|
||||||
|
|
||||||
class BucketFile(BucketFileBase):
|
class BucketFile(BucketFileBase):
|
||||||
id: int
|
|
||||||
bucket_id: int
|
bucket_id: int
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user