easydiffusion/main.py

66 lines
1.6 KiB
Python
Raw Normal View History

2022-08-25 18:16:31 +02:00
from typing import Optional

import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from starlette.responses import FileResponse
# Base URL of the backend server that /ping probes and /image forwards to.
LOCAL_SERVER_URL = 'http://stability-ai:5000'
# Endpoint that receives the {"input": {...}} payloads built by image().
PREDICT_URL = f'{LOCAL_SERVER_URL}/predictions'

app = FastAPI()
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
    """JSON body accepted by the ``/image`` endpoint.

    All tunable parameters are declared as strings and are forwarded to the
    prediction server verbatim (no numeric conversion happens here).
    """

    prompt: str
    # Optional base64-encoded starting image; None means no init image is sent.
    init_image: Optional[str] = None
    num_outputs: str = "1"
    num_inference_steps: str = "50"
    guidance_scale: str = "7.5"
    width: str = "512"
    height: str = "512"
    # The special value "-1" makes image() omit the seed from the request.
    seed: str = "30000"
    # Forwarded only together with init_image (see image()).
    prompt_strength: str = "0.8"
@app.get('/')
def read_root():
    """Serve the web UI page ``index.html`` (path relative to the working directory)."""
    index_page = 'index.html'
    return FileResponse(index_page)
@app.get('/ping')
async def ping():
    """Report whether the backend server at LOCAL_SERVER_URL is reachable.

    Returns:
        ``{'OK'}`` if any HTTP response is received (the status code is not
        inspected), ``{'ERROR'}`` if the request fails (connection refused,
        DNS failure, timeout, ...). The set literals are kept unchanged for
        compatibility with existing clients; FastAPI encodes them as JSON
        lists (``["OK"]`` / ``["ERROR"]``).
    """
    try:
        # requests is synchronous, so this call blocks the event loop while it
        # runs; the timeout bounds that wait (without it, an unresponsive
        # backend could hang this handler indefinitely).
        requests.get(LOCAL_SERVER_URL, timeout=5)
    except requests.RequestException:
        # Only network/HTTP-client failures mean "unreachable"; other
        # exceptions (e.g. KeyboardInterrupt) are no longer swallowed.
        return {'ERROR'}
    return {'OK'}
@app.post('/image')
async def image(req : ImageRequest):
    """Forward an image request to the prediction server and relay its answer.

    Builds an ``{"input": {...}}`` payload from *req*, POSTs it to
    PREDICT_URL and returns the server's JSON response unchanged.

    Raises:
        HTTPException: status 500 if the prediction server cannot be reached,
            or if it answers with a non-200 status. ``detail`` carries the
            client error message or the server's response text.
    """
    model_input = {
        "prompt": req.prompt,
        "num_outputs": req.num_outputs,
        "num_inference_steps": req.num_inference_steps,
        "width": req.width,
        "height": req.height,
        "seed": req.seed,
        "guidance_scale": req.guidance_scale,
    }

    if req.init_image is not None:
        model_input['init_image'] = req.init_image
        model_input['prompt_strength'] = req.prompt_strength

    # "-1" is the sentinel for "no fixed seed": leave the key out entirely.
    if req.seed == "-1":
        del model_input['seed']

    try:
        # NOTE(review): requests is synchronous, so this blocks the event loop
        # for the whole generation. No timeout is set on purpose, because
        # generation time is unbounded here; consider a threadpool/async client.
        res = requests.post(PREDICT_URL, json={"input": model_input})
    except requests.RequestException as err:
        # Surface connection/transport failures with a readable detail instead
        # of an unhandled exception.
        raise HTTPException(status_code=500, detail=str(err)) from err

    if res.status_code != 200:
        raise HTTPException(status_code=500, detail=res.text)

    return res.json()
@app.get('/media/ding.mp3')
def read_root():
    """Serve the ``media/ding.mp3`` sound file.

    NOTE(review): this function reuses the name ``read_root`` from the ``/``
    handler defined earlier, so the module-level name ends up bound to this
    function. Both routes were registered by their decorators, but the
    duplicate name trips linters (F811); consider renaming (e.g. ``read_ding``)
    once no client depends on the generated OpenAPI operation id.
    """
    sound_file = 'media/ding.mp3'
    return FileResponse(sound_file)