2022-08-23 22:28:18 +02:00
|
|
|
from fastapi import FastAPI
|
|
|
|
from starlette.responses import FileResponse
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
|
|
import requests
|
|
|
|
|
|
|
|
# Base URL of the local prediction server this app forwards requests to.
LOCAL_SERVER_URL = 'http://localhost:5000'
# Endpoint on that server which the /image handler POSTs generation input to.
PREDICT_URL = f'{LOCAL_SERVER_URL}/predictions'

# The ASGI application; every route handler below registers on it.
app = FastAPI()
|
|
|
|
|
2022-08-24 09:03:35 +02:00
|
|
|
# Request body accepted by POST /image.
#
# Default values are taken from https://huggingface.co/blog/stable_diffusion.
# Every tuning parameter is declared as a string and is forwarded to the
# prediction server as-is. `prompt` has no default, so it is required.
class ImageRequest(BaseModel):
    prompt: str
    num_outputs: str = '1'
    num_inference_steps: str = '50'
    guidance_scale: str = '7.5'
    width: str = '512'
    height: str = '512'
    # The /image handler drops this key from the forwarded input when it
    # equals "-1".
    seed: str = '30000'
|
2022-08-23 22:28:18 +02:00
|
|
|
|
|
|
|
@app.get('/')
def read_root():
    # Serve the front-end page. The path is relative, so it is resolved
    # against the process's current working directory.
    index_page = 'index.html'
    return FileResponse(index_page)
|
|
|
|
|
|
|
|
@app.get('/ping')
async def ping():
    """Report whether the local prediction server at LOCAL_SERVER_URL answers.

    Any HTTP response (regardless of status code) counts as reachable.

    Returns:
        ``{'OK'}`` when the server responded, ``{'ERROR'}`` when the request
        failed or timed out. The one-element set shape is kept unchanged for
        existing clients (FastAPI's encoder emits it as a JSON array).
    """
    # Imported here so this handler is self-contained; starlette is already
    # a dependency of this file.
    from starlette.concurrency import run_in_threadpool

    # Bound the health check so a hung server cannot hang /ping forever.
    timeout_seconds = 5

    try:
        # requests is synchronous: run it in a worker thread instead of
        # blocking the event loop that serves every other request.
        await run_in_threadpool(requests.get, LOCAL_SERVER_URL, timeout=timeout_seconds)
    except requests.exceptions.RequestException:
        # Connection refused, timeout, etc. Catching only requests' errors
        # (instead of a bare except) lets cancellation and interrupts propagate.
        return {'ERROR'}
    else:
        return {'OK'}
|
|
|
|
|
|
|
|
def _build_prediction_payload(req):
    """Translate an ImageRequest into the JSON body POSTed to PREDICT_URL.

    All values are copied through unchanged. When ``req.seed == "-1"`` the
    ``seed`` key is omitted entirely; presumably this lets the prediction
    server pick its own seed (behaviour of the server not visible here).
    """
    inputs = {
        "prompt": req.prompt,
        "num_outputs": req.num_outputs,
        "num_inference_steps": req.num_inference_steps,
        "width": req.width,
        "height": req.height,
        "seed": req.seed,
        "guidance_scale": req.guidance_scale,
    }
    if req.seed == "-1":
        del inputs["seed"]
    return {"input": inputs}


@app.post('/image')
async def image(req: ImageRequest):
    """Forward an image-generation request to the local prediction server.

    Args:
        req: The validated request body.

    Returns:
        The prediction server's JSON response, decoded and passed through
        unchanged.

    Raises:
        requests.exceptions.RequestException: if the prediction server cannot
            be reached.
        ValueError: if the server's response body is not valid JSON.
    """
    # Imported here so this handler is self-contained; starlette is already
    # a dependency of this file.
    from starlette.concurrency import run_in_threadpool

    # Fail fast if the server cannot be connected to, but place no limit on
    # the read: generation may legitimately take a long time.
    connect_timeout_seconds = 10

    payload = _build_prediction_payload(req)
    # requests is synchronous: run it in a worker thread so a long generation
    # does not block the event loop (and with it /ping and the static routes).
    res = await run_in_threadpool(
        requests.post,
        PREDICT_URL,
        json=payload,
        timeout=(connect_timeout_seconds, None),
    )
    return res.json()
|
|
|
|
|
|
|
|
@app.get('/ding.mp3')
def read_ding():
    """Serve the ``ding.mp3`` audio file.

    The path is relative, so it is resolved against the process's current
    working directory.
    """
    # Named distinctly from the '/' handler so it no longer rebinds the
    # module-level `read_root` or duplicates that route's name.
    return FileResponse('ding.mp3')
|