import json
import os
import re
import sys
import traceback
from typing import Optional

SCRIPT_DIR = os.getcwd()
print('started in ', SCRIPT_DIR)

# Directory holding the UI files (index.html, media/, modifiers.json are served from it below).
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
if not SD_UI_DIR:
    # fail fast with a readable message instead of a TypeError from os.path.dirname(None)
    raise RuntimeError('The SD_UI_PATH environment variable must be set to the UI directory')
# presumably makes the sibling `sd_internal` package importable (it is imported further down)
sys.path.append(os.path.dirname(SD_UI_DIR))

# config.json / config.bat / config.sh are read and written here by the /app_config endpoints
CONFIG_DIR = os.path.join(SD_UI_DIR, '..', 'scripts')

OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder

from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
import logging

from sd_internal import Request, Response

app = FastAPI()

# becomes True once ping() has successfully loaded the model checkpoint
model_loaded: bool = False
# True while ping() is loading the model; ping() and stop() answer {'ERROR'} during this window
model_is_loading: bool = False

# NOTE(review): not referenced by any code visible in this file — confirm it is still needed
modifiers_cache = None
# output folder in the user's home directory, reported by GET /output_dir
# (presumably where generated images are saved — confirm against sd_internal)
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)

# don't show access log entries whose log message contains any of these path prefixes
# (LogSuppressFilter matches them anywhere in the formatted message, not only at the start of the URL)
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/modifier-thumbnails']

# serve the UI's media/ folder as static files under the /media URL prefix
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")

# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
    """Request body for POST /image.

    Fields that default to ``None`` are optional and are declared as
    ``Optional[str]`` so that an explicit JSON ``null`` is accepted as well as
    an omitted field.
    """

    session_id: str = "session"
    prompt: str = ""
    negative_prompt: str = ""
    init_image: Optional[str] = None # base64
    mask: Optional[str] = None # base64
    num_outputs: int = 1
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    width: int = 512
    height: int = 512
    seed: int = 42
    prompt_strength: float = 0.8
    sampler: Optional[str] = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
    # allow_nsfw: bool = False  (currently disabled; also commented out in the /image handler)
    save_to_disk_path: Optional[str] = None
    turbo: bool = True
    use_cpu: bool = False
    use_full_precision: bool = False
    use_face_correction: Optional[str] = None # or "GFPGANv1.3"
    use_upscale: Optional[str] = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
    show_only_filtered_image: bool = False

    # when False, /image buffers the progress stream and returns only the final message
    stream_progress_updates: bool = False
    # only honoured when stream_progress_updates is True
    stream_image_progress: bool = False

class SetAppConfigRequest(BaseModel):
    """Request body for POST /app_config.

    ``update_branch`` is written to config.json, config.bat and config.sh in
    CONFIG_DIR by the handler.
    """

    update_branch: str = "main"

@app.get('/')
def read_root():
    """Serve the UI's index.html with headers that forbid browser caching."""
    no_cache_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }
    index_html = os.path.join(SD_UI_DIR, 'index.html')
    return FileResponse(index_html, headers=no_cache_headers)

@app.get('/ping')
async def ping():
    """Health check that lazily loads the model on its first call.

    Returns ``{'OK'}`` when the model is (or has just been) loaded, and
    ``{'ERROR'}`` while another load is in progress. The checkpoint name passed
    to ``runtime.load_model_ckpt`` is ``"custom-model"`` when
    ``custom-model.ckpt`` exists in SCRIPT_DIR, otherwise ``"sd-v1-4"``.

    Raises:
        HTTPException: 500 if loading the model fails. The loading flag is
            cleared in that case so a later ping can retry.
    """
    global model_loaded, model_is_loading

    if model_loaded:
        return {'OK'}

    if model_is_loading:
        return {'ERROR'}

    model_is_loading = True
    try:
        from sd_internal import runtime

        custom_weight_path = os.path.join(SCRIPT_DIR, 'custom-model.ckpt')
        ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
        # NOTE(review): synchronous call inside an async handler — it blocks the event loop until loading finishes
        runtime.load_model_ckpt(ckpt_to_use=ckpt_to_use)

        model_loaded = True
    except Exception as e:
        print(traceback.format_exc())
        # must be raised (not returned) for FastAPI to send a real 500 response
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # always clear the flag, otherwise a failed load would make every later ping answer ERROR
        model_is_loading = False

    return {'OK'}

@app.post('/image')
def image(req : ImageRequest):
    """Render images for ``req`` via ``runtime.mk_img``.

    ``mk_img`` produces an iterable of JSON-encoded messages. When
    ``req.stream_progress_updates`` is true that iterable is forwarded to the
    client as a streaming ``application/json`` response. Otherwise
    (compatibility mode) it is drained here and only the last message is
    decoded and returned.

    Raises:
        HTTPException: 500 if ``mk_img`` raises, or, in compatibility mode, if
            iterating or decoding the results fails or no message is produced.
            Errors raised while a StreamingResponse is being sent occur after
            this function has returned and are not converted here.
    """
    from sd_internal import runtime

    r = Request()
    r.session_id = req.session_id
    r.prompt = req.prompt
    r.negative_prompt = req.negative_prompt
    r.init_image = req.init_image
    r.mask = req.mask
    r.num_outputs = req.num_outputs
    r.num_inference_steps = req.num_inference_steps
    r.guidance_scale = req.guidance_scale
    r.width = req.width
    r.height = req.height
    r.seed = req.seed
    r.prompt_strength = req.prompt_strength
    r.sampler = req.sampler
    # r.allow_nsfw = req.allow_nsfw
    r.turbo = req.turbo
    r.use_cpu = req.use_cpu
    r.use_full_precision = req.use_full_precision
    r.save_to_disk_path = req.save_to_disk_path
    r.use_upscale = req.use_upscale
    r.use_face_correction = req.use_face_correction
    r.show_only_filtered_image = req.show_only_filtered_image

    r.stream_progress_updates = True # the underlying implementation only supports streaming
    r.stream_image_progress = req.stream_image_progress
    if not req.stream_progress_updates:
        # in compatibility mode only the last message is returned, so intermediate images would be discarded
        r.stream_image_progress = False

    try:
        res = runtime.mk_img(r)

        if req.stream_progress_updates:
            return StreamingResponse(res, media_type='application/json')

        # compatibility mode: buffer the streaming responses, and return the last one
        last_result = None
        for result in res:
            last_result = result

        if last_result is None:
            # previously json.loads(None) failed here with an unhelpful TypeError
            raise RuntimeError('mk_img produced no results')

        return json.loads(last_result)
    except Exception as e:
        print(traceback.format_exc())
        # must be raised (not returned) for FastAPI to send a real 500 response
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get('/image/stop')
def stop():
    """Set ``runtime.stop_processing`` to request that rendering stop.

    Returns ``{'ERROR'}`` while the model is still being loaded by ``ping``,
    otherwise ``{'OK'}`` once the flag has been set. The flag is presumably
    polled by the render loop in ``sd_internal.runtime`` — confirm there.

    Raises:
        HTTPException: 500 if importing the runtime or setting the flag fails.
    """
    if model_is_loading:
        return {'ERROR'}

    try:
        from sd_internal import runtime
        runtime.stop_processing = True
    except Exception as e:
        print(traceback.format_exc())
        # must be raised (not returned) for FastAPI to send a real 500 response
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {'OK'}

@app.get('/image/tmp/{session_id}/{img_id}')
def get_image(session_id, img_id):
    """Stream a temporary JPEG image stored in ``runtime.temp_images``.

    Images are looked up under the key ``"<session_id>/<img_id>"``.

    Raises:
        HTTPException: 404 if no image is stored under that key (previously an
            unhandled KeyError surfaced as a 500).
    """
    from sd_internal import runtime

    key = session_id + '/' + img_id
    try:
        buf = runtime.temp_images[key]
    except KeyError:
        raise HTTPException(status_code=404, detail=f'Unknown temporary image: {key}') from None

    # rewind so the whole image is sent even if the buffer was read or written before
    buf.seek(0)
    return StreamingResponse(buf, media_type='image/jpeg')

@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
    """Persist the update branch to config.json, config.bat and config.sh in CONFIG_DIR.

    The branch name is interpolated verbatim into a batch ``@set`` line and a
    shell ``export`` line. To prevent command injection it is therefore
    restricted to ASCII letters, digits, ``.``, ``_``, ``/`` and ``-``, and
    must start with a letter or digit.

    Raises:
        HTTPException: 400 for a disallowed branch name, 500 if writing any
            config file fails.
    """
    branch = req.update_branch
    if not re.fullmatch(r'[A-Za-z0-9][A-Za-z0-9._/-]*', branch):
        raise HTTPException(status_code=400, detail=f'Invalid branch name: {branch!r}')

    try:
        config = {
            'update_branch': branch
        }

        config_json_str = json.dumps(config)
        config_bat_str = f'@set update_branch={branch}'
        config_sh_str = f'export update_branch={branch}'

        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
        config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')

        with open(config_json_path, 'w') as f:
            f.write(config_json_str)

        with open(config_bat_path, 'w') as f:
            f.write(config_bat_str)

        with open(config_sh_path, 'w') as f:
            f.write(config_sh_str)
    except Exception as e:
        print(traceback.format_exc())
        # must be raised (not returned) for FastAPI to send a real 500 response
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {'OK'}

@app.get('/app_config')
def getAppConfig():
    """Return the parsed contents of CONFIG_DIR/config.json.

    Raises:
        HTTPException: 500 with detail "No config file" when the file does not
            exist, or 500 with the error text if reading or parsing it fails.
    """
    config_json_path = os.path.join(CONFIG_DIR, 'config.json')

    try:
        with open(config_json_path, 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        # expected when no branch has been saved yet; no traceback needed
        raise HTTPException(status_code=500, detail="No config file") from None
    except Exception as e:
        print(traceback.format_exc())
        # must be raised (not returned) for FastAPI to send a real 500 response
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get('/modifiers.json')
def read_modifiers():
    """Serve the UI's modifiers.json file with headers that forbid browser caching."""
    modifiers_file = os.path.join(SD_UI_DIR, 'modifiers.json')
    return FileResponse(
        modifiers_file,
        headers={
            "Cache-Control": "no-cache, no-store, must-revalidate",
            "Pragma": "no-cache",
            "Expires": "0",
        },
    )

@app.get('/output_dir')
def read_home_dir():
    """Report the output folder path (``outpath``).

    The path is wrapped in a one-element set. FastAPI presumably encodes this
    as a single-item JSON array, which clients may rely on.
    """
    output_dirs = {outpath}
    return output_dirs

# don't log certain requests
class LogSuppressFilter(logging.Filter):
    """Drop log records whose message mentions a suppressed path prefix.

    Each entry of ACCESS_LOG_SUPPRESS_PATH_PREFIXES is searched for anywhere
    in the fully formatted message, not only at its start.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        mentions_suppressed = any(prefix in message for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES)
        return not mentions_suppressed

# hide access-log lines for the paths listed in ACCESS_LOG_SUPPRESS_PATH_PREFIXES
_access_log_filter = LogSuppressFilter()
logging.getLogger('uvicorn.access').addFilter(_access_log_filter)

# start the browser ui (runs whenever this module is imported)
import webbrowser
webbrowser.open('http://localhost:9000')