workable plan

This commit is contained in:
caranicas 2022-09-24 13:47:30 -04:00
parent d1e792686e
commit 22a0b3be45
11 changed files with 370 additions and 288 deletions

View File

@ -1,28 +0,0 @@
import { useInfiniteQuery } from "@tanstack/react-query";
export const useAyncGeneratorQuery = (key: string, asyncGeneratorFn: () => Generator<string[], string[], unknown>) => {
const queryFn = React.useCallback((_: any, params: CustomQueryParams) => {
if (!params) { return Promise.resolve([]) }
return Promise.resolve(params.eventData);
}, [])
const { data, fetchMore } = useInfiniteQuery<string[], string, CustomQueryParams>(key, queryFn as any, { getFetchMore: () => ({ eventData: [] }) })
const customStatus = React.useRef<Status>('success');
React.useEffect(() => {
(async function doReceive() {
try {
for await (let data of asyncGeneratorFn()) {
fetchMore({ eventData: data });
}
} catch (e) {
customStatus.current = 'error';
}
})();
}, [asyncGeneratorFn, fetchMore])
return { status: customStatus.current, data };
}

View File

@ -1,21 +0,0 @@
type CallbackType = (data: string[]) => void;
type InitCallbackType = (cb: CallbackType) => void;
export const useCallbackQuery = (key: string, initCallbackQuery: InitCallbackType) => {
const queryFn = React.useCallback((_: any, params: CustomQueryParams) => {
if (!params) { return Promise.resolve([]) }
return Promise.resolve(params.eventData);
}, [])
const { data, fetchMore } = useInfiniteQuery<string[], string, CustomQueryParams>(key, queryFn as any, { getFetchMore: () => ({ eventData: [] }) })
const callback = React.useCallback((data) => { fetchMore({ eventData: data }) }, [fetchMore]);
React.useEffect(() => { initCallbackQuery(callback); }, [callback, initCallbackQuery])
const customStatus = React.useRef<Status>('success');
return { status: customStatus.current, data };
}

View File

@ -1,19 +0,0 @@
import { useQuery, useQueryClient } from 'react-query'
export const useEventSourceQuery = (queryKey, url, eventName) => {
const queryClient = useQueryClient()
const fetchData = () => {
const evtSource = new EventSource(url)
evtSource.addEventListener(eventName, (event) => {
const eventData = event.data && JSON.parse(event.data)
if (eventData) {
queryClient.setQueryData(queryKey, eventData)
}
})
}
return useQuery(queryKey, fetchData)
}

View File

@ -1,47 +0,0 @@
import * as React from 'react';
import { useInfiniteQuery } from 'react-query';
interface CustomQueryParams {
eventData: string[];
}
type Status = "success" | "loading" | "error";
export const useEventSourceQuery = (key: string, url: string, eventName: string) => {
const eventSource = React.useRef<EventSource>(new EventSource(url));
const queryFn = React.useCallback((_: any, params: CustomQueryParams) => {
if (!params) { return Promise.resolve([]) }
return Promise.resolve(params.eventData);
}, [])
const { data, fetchMore } = useInfiniteQuery<string[], string, CustomQueryParams>(key, queryFn as any, { getFetchMore: () => ({ eventData: [] }) })
const customStatus = React.useRef<Status>('success');
React.useEffect(() => {
const evtSource = eventSource.current;
const onEvent = function (ev: MessageEvent | Event) {
if (!e.data) {
return;
}
// Let's assume here we receive multiple data, ie. e.data is an array.
fetchMore({ eventData: e.data });
}
const onError = () => { customStatus.current = 'error' };
evtSource.addEventListener(eventName, onEvent);
evtSource.addEventListener('error', onError);
return () => {
evtSource.removeEventListener(eventName, onEvent);
evtSource.removeEventListener('error', onError);
}
}, [url, eventName, fetchMore])
return { status: customStatus.current, data };
}

View File

@ -79,7 +79,9 @@ export const doMakeImage = async (reqBody: ImageRequest) => {
},
body: JSON.stringify(reqBody),
});
console.log('doMakeImage= GOT RESPONSE', res);
const data = await res.json();
return data;
// const data = await res.json();
// return data;
return res;
};

View File

@ -1,5 +1,5 @@
import React from "react";
import { API_URL } from "../../../../api";
import { API_URL } from "../../../../../../api";
const url = `${API_URL}/ding.mp3`;

View File

@ -1,49 +1,108 @@
/* eslint-disable @typescript-eslint/no-unnecessary-type-assertion */
/* eslint-disable @typescript-eslint/prefer-ts-expect-error */
/* eslint-disable @typescript-eslint/naming-convention */
import React from "react";
import React, { useEffect } from "react";
import { useImageCreate, ImageRequest } from "../../../../../stores/imageCreateStore";
import { useImageQueue } from "../../../../../stores/imageQueueStore";
import {
FetchingStates,
useImageFetching
} from "../../../../../stores/imageFetchingStore";
import { v4 as uuidv4 } from "uuid";
import { useRandomSeed } from "../../../../../utils";
import { doMakeImage } from "../../../../../api";
import {
MakeButtonStyle, // @ts-expect-error
} from "./makeButton.css.ts";
import { useTranslation } from "react-i18next";
import AudioDing from "./audioDing";
import { parse } from "node:path/win32";
export default function MakeButton() {
const { t } = useTranslation();
const parallelCount = useImageCreate((state) => state.parallelCount);
const builtRequest = useImageCreate((state) => state.builtRequest);
const addNewImage = useImageQueue((state) => state.addNewImage);
const hasQueue = useImageQueue((state) => state.hasQueuedImages());
const isRandomSeed = useImageCreate((state) => state.isRandomSeed());
const setRequestOption = useImageCreate((state) => state.setRequestOptions);
// const makeImages = () => {
// // potentially update the seed
// if (isRandomSeed) {
// // update the seed for the next time we click the button
// setRequestOption("seed", useRandomSeed());
// }
const addNewImage = useImageQueue((state) => state.addNewImage);
const hasQueue = useImageQueue((state) => state.hasQueuedImages());
const { id, options } = useImageQueue((state) => state.firstInQueue());
// // the request that we have built
// const req = builtRequest();
const setStatus = useImageFetching((state) => state.setStatus);
const appendData = useImageFetching((state) => state.appendData);
// };
const parseRequest = async (reader: ReadableStreamDefaultReader<Uint8Array>) => {
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
const queueImageRequest = (req: ImageRequest) => {
if (done as boolean) {
console.log("DONE");
setStatus(FetchingStates.COMPLETE);
break;
}
const jsonStr = decoder.decode(value);
try {
const update = JSON.parse(jsonStr);
if (update.status === "progress") {
console.log("PROGRESS");
setStatus(FetchingStates.PROGRESSING);
}
else if (update.status === "succeeded") {
console.log("succeeded");
setStatus(FetchingStates.SUCCEEDED);
// appendData(update.data);
}
else {
console.log("extra?", update);
// appendData(update.data);
}
}
catch (e) {
console.log('PARSE ERRROR')
console.log(e)
debugger;
// appendData(update.data);
}
}
}
const startStream = async (req: ImageRequest) => {
const streamReq = {
...req,
// stream_image_progress: false,
};
console.log("testStream", streamReq);
try {
const res = await doMakeImage(streamReq);
// @ts-expect-error
const reader = res.body.getReader();
void parseRequest(reader);
} catch (e) {
console.log('e');
}
}
const queueImageRequest = async (req: ImageRequest) => {
// the actual number of request we will make
const requests = [];
// the number of images we will make
let { num_outputs } = req;
// if making fewer images than the parallel count
// then it is only 1 request
if (parallelCount > num_outputs) {
requests.push(num_outputs);
} else {
@ -63,7 +122,6 @@ export default function MakeButton() {
}
}
// make the requests
requests.forEach((num, index) => {
// get the seed we want to use
let seed = req.seed;
@ -82,63 +140,42 @@ export default function MakeButton() {
});
}
const testStream = async (req: ImageRequest) => {
const streamReq = {
...req,
stream_progress_updates: true,
// stream_image_progress: false,
session_id: uuidv4(),
};
console.log("testStream", streamReq);
try {
const res = await fetch('http://localhost:9000/image', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify(streamReq)
});
console.log('res', res);
const reader = res.body.getReader();
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
console.log(text);
}
} catch (e) {
console.log(e);
debugger;
}
}
const makeImages = () => {
const makeImageQueue = async () => {
// potentially update the seed
if (isRandomSeed) {
// update the seed for the next time we click the button
setRequestOption("seed", useRandomSeed());
}
// the request that we have built
const req = builtRequest();
//queueImageRequest(req);
void testStream(req);
await queueImageRequest(req);
// void startStream(req);
};
useEffect(() => {
const makeImages = async (options: ImageRequest) => {
// potentially update the seed
await startStream(options);
}
if (hasQueue) {
makeImages(options).catch((e) => {
console.log('HAS QUEUE ERROR');
console.log(e);
});
}
}, [hasQueue, id, options, startStream]);
return (
<button
className={MakeButtonStyle}
onClick={makeImages}
onClick={() => {
setStatus(FetchingStates.FETCHING);
void makeImageQueue();
}}
disabled={hasQueue}
>
{t("home.make-img-btn")}

View File

@ -1,18 +1,24 @@
/* eslint-disable @typescript-eslint/strict-boolean-expressions */
import React, { useEffect, useState, useRef } from "react";
import { useImageQueue } from "../../../stores/imageQueueStore";
import { ImageRequest, useImageCreate } from "../../../stores/imageCreateStore";
import { useImageFetching } from "../../../stores/imageFetchingStore";
import { useImageDisplay } from "../../../stores/imageDisplayStore";
import { useQuery, useQueryClient } from "@tanstack/react-query";
// import { ImageRequest, useImageCreate } from "../../../stores/imageCreateStore";
import {
doMakeImage,
MakeImageKey,
ImageReturnType,
ImageOutput,
} from "../../../api";
// import { useQuery, useQueryClient } from "@tanstack/react-query";
import AudioDing from "./audioDing";
// import {
// API_URL,
// doMakeImage,
// MakeImageKey,
// ImageReturnType,
// ImageOutput,
// } from "../../../api";
// import AudioDing from "../creationPanel/basicCreation/makeButton/audioDing";
// import GeneratedImage from "../../molecules/generatedImage";
// import DrawImage from "../../molecules/drawImage";
@ -22,132 +28,207 @@ import CompletedImages from "./completedImages";
import {
displayPanel,
displayContainer,
previousImages,
// displayContainer,
// previousImages,
// @ts-expect-error
} from "./displayPanel.css.ts";
export interface CompletedImagesType {
id: string;
data: string;
info: ImageRequest;
}
// export interface CompletedImagesType {
// id: string;
// data: string;
// info: ImageRequest;
// }
const idDelim = "_batch";
export default function DisplayPanel() {
const dingRef = useRef<HTMLAudioElement>(null);
const isSoundEnabled = useImageCreate((state) => state.isSoundEnabled());
// const dingRef = useRef<HTMLAudioElement>(null);
// const isSoundEnabled = useImageCreate((state) => state.isSoundEnabled());
// @ts-expect-error
const { id, options } = useImageQueue((state) => state.firstInQueue());
const removeFirstInQueue = useImageQueue((state) => state.removeFirstInQueue);
// // @ts-expect-error
// const { id, options } = useImageQueue((state) => state.firstInQueue());
// const removeFirstInQueue = useImageQueue((state) => state.removeFirstInQueue);
const [currentImage, setCurrentImage] = useState<CompletedImagesType | null>(
null
);
// const [currentImage, setCurrentImage] = useState<CompletedImagesType | null>(
// null
// );
const [isEnabled, setIsEnabled] = useState(false);
// const [isEnabled, setIsEnabled] = useState(false);
const [isLoading, setIsLoading] = useState(true);
// const [isLoading, setIsLoading] = useState(true);
const { status, data } = useQuery(
[MakeImageKey, id],
async () => await doMakeImage(options),
{
enabled: isEnabled,
}
);
// const { status, data } = useQuery(
// [MakeImageKey, id],
// async () => await doMakeImage(options),
// {
// enabled: isEnabled,
// }
// );
// const { status, data } = useEventSourceQuery(
// MakeImageKey,
// // [MakeImageKey, id],
// // async () => await doMakeImage(options),
// // {
// // enabled: isEnabled,
// // }
// );
// update the enabled state when the id changes
useEffect(() => {
setIsEnabled(void 0 !== id);
}, [id]);
// useEffect(() => {
// setIsEnabled(void 0 !== id);
// }, [id]);
// const _handleStreamData = async (res: typeof ReadableStream) => {
// console.log("_handleStreamData");
// let reader;
// // @ts-expect-error
// if (res.body.locked) {
// console.log("locked");
// }
// else {
// reader = res.body.getReader();
// }
// console.log("reader", reader);
// const decoder = new TextDecoder();
// while (true) {
// const { done, value } = await reader.read();
// const text = decoder.decode(value);
// console.log("DECODE", done);
// console.log(text);
// if (text.status === "progress") {
// console.log("PROGRESS");
// }
// else if (text.status === "succeeded") {
// console.log("succeeded");
// }
// else {
// console.log("extra?")
// }
// console.log("-----------------");
// if (done as boolean) {
// reader.releaseLock();
// break;
// }
// }
// };
// useEffect(() => {
// const fetch = async () => {
// const res = await doMakeImage(options);
// void _handleStreamData(res);
// }
// if (isEnabled) {
// console.log('isEnabled');
// debugger;
// fetch()
// .catch((err) => {
// console.error(err);
// });
// }
// }, [isEnabled, options, _handleStreamData]);
// helper for the loading state to be enabled aware
useEffect(() => {
if (isEnabled && status === "loading") {
setIsLoading(true);
} else {
setIsLoading(false);
}
}, [isEnabled, status]);
// useEffect(() => {
// if (isEnabled && status === "loading") {
// setIsLoading(true);
// } else {
// setIsLoading(false);
// }
// }, [isEnabled, status]);
// this is where there loading actually happens
useEffect(() => {
// query is done
if (status === "success") {
// check to make sure that the image was created
if (data.status === "succeeded") {
if (isSoundEnabled) {
// not awaiting the promise or error handling
void dingRef.current?.play();
}
removeFirstInQueue();
}
}
}, [status, data, removeFirstInQueue, dingRef, isSoundEnabled]);
// useEffect(() => {
// console.log('DISPLATPANEL: status', status);
// console.log('DISPLATPANEL: data', data);
// // query is done
// if (status === "success") {
// // check to make sure that the image was created
// void _handleStreamData(data);
// // if (data.status === "succeeded") {
// // if (isSoundEnabled) {
// // // not awaiting the promise or error handling
// // void dingRef.current?.play();
// // }
// // removeFirstInQueue();
// // }
// }
// }, [status, data, removeFirstInQueue, dingRef, isSoundEnabled, _handleStreamData]);
/* COMPLETED IMAGES */
const queryClient = useQueryClient();
const [completedImages, setCompletedImages] = useState<CompletedImagesType[]>(
[]
);
// const queryClient = useQueryClient();
// const [completedImages, setCompletedImages] = useState<CompletedImagesType[]>(
// []
// );
const completedIds = useImageQueue((state) => state.completedImageIds);
const clearCachedIds = useImageQueue((state) => state.clearCachedIds);
// const completedIds = useImageQueue((state) => state.completedImageIds);
// const clearCachedIds = useImageQueue((state) => state.clearCachedIds);
// this is where we generate the list of completed images
useEffect(() => {
const completedQueries = completedIds.map((id) => {
const imageData = queryClient.getQueryData([MakeImageKey, id]);
return imageData;
}) as ImageReturnType[];
// useEffect(() => {
// const completedQueries = completedIds.map((id) => {
// const imageData = queryClient.getQueryData([MakeImageKey, id]);
// return imageData;
// }) as ImageReturnType[];
if (completedQueries.length > 0) {
// map the completedImagesto a new array
// and then set the state
const temp = completedQueries
.map((query, index) => {
if (void 0 !== query) {
return query.output.map((data: ImageOutput, index: number) => {
return {
id: `${completedIds[index]}${idDelim}-${data.seed}-${index}`,
data: data.data,
info: { ...query.request, seed: data.seed },
};
});
}
})
.flat()
.reverse()
.filter((item) => void 0 !== item) as CompletedImagesType[]; // remove undefined items
// if (completedQueries.length > 0) {
// // map the completedImagesto a new array
// // and then set the state
// const temp = completedQueries
// .map((query, index) => {
// if (void 0 !== query) {
// return query.output.map((data: ImageOutput, index: number) => {
// return {
// id: `${completedIds[index]}${idDelim}-${data.seed}-${index}`,
// data: data.data,
// info: { ...query.request, seed: data.seed },
// };
// });
// }
// })
// .flat()
// .reverse()
// .filter((item) => void 0 !== item) as CompletedImagesType[]; // remove undefined items
setCompletedImages(temp);
// setCompletedImages(temp);
// could move this to the useEffect for completedImages
if (temp.length > 0) {
setCurrentImage(temp[0]);
} else {
setCurrentImage(null);
}
} else {
setCompletedImages([]);
setCurrentImage(null);
}
}, [setCompletedImages, setCurrentImage, queryClient, completedIds]);
// // could move this to the useEffect for completedImages
// if (temp.length > 0) {
// setCurrentImage(temp[0]);
// } else {
// setCurrentImage(null);
// }
// } else {
// setCompletedImages([]);
// setCurrentImage(null);
// }
// }, [setCompletedImages, setCurrentImage, queryClient, completedIds]);
// this is how we remove them
const removeImages = () => {
completedIds.forEach((id) => {
queryClient.removeQueries([MakeImageKey, id]);
});
clearCachedIds();
};
// const removeImages = () => {
// completedIds.forEach((id) => {
// queryClient.removeQueries([MakeImageKey, id]);
// });
// clearCachedIds();
// };
return (
<div className={displayPanel}>
<AudioDing ref={dingRef}></AudioDing>
DISPLAY
{/* <AudioDing ref={dingRef}></AudioDing>
<div className={displayContainer}>
<CurrentDisplay
isLoading={isLoading}
@ -160,7 +241,7 @@ export default function DisplayPanel() {
images={completedImages}
setCurrentDisplay={setCurrentImage}
></CompletedImages>
</div>
</div> */}
</div>
);
}

View File

@ -70,7 +70,7 @@ export interface ImageRequest {
init_image: undefined | string;
prompt_strength: undefined | number;
sampler: typeof SAMPLER_OPTIONS[number];
stream_progress_updates: boolean;
stream_progress_updates: true;
stream_image_progress: boolean;
}
@ -136,7 +136,7 @@ export const useImageCreate = create<ImageCreateState>(
show_only_filtered_image: true,
init_image: undefined,
sampler: "plms",
stream_progress_updates: false,
stream_progress_updates: true,
stream_image_progress: false
},

View File

@ -0,0 +1,71 @@
import create from "zustand";
import produce from "immer";
export const FetchingStates = {
IDLE: "IDLE",
FETCHING: "FETCHING",
PROGRESSING: "PROGRESSING",
SUCCEEDED: "SUCCEEDED",
COMPLETE: "COMPLETE",
ERROR: "ERROR",
} as const;
interface ImageFetchingState {
status: typeof FetchingStates[keyof typeof FetchingStates];
step: number;
totalSteps: number;
data: string;
appendData: (data: string) => void;
reset: () => void;
setStatus: (status: typeof FetchingStates[keyof typeof FetchingStates]) => void;
setStep: (step: number) => void;
setTotalSteps: (totalSteps: number) => void;
}
export const useImageFetching = create<ImageFetchingState>((set) => ({
status: FetchingStates.IDLE,
step: 0,
totalSteps: 0,
data: '',
// use produce to make sure we don't mutate state
appendData: (data: string) => {
set(
produce((state: ImageFetchingState) => {
// eslint-disable-next-line @typescript-eslint/restrict-plus-operands
state.data += data;
})
);
},
reset: () => {
set(
produce((state: ImageFetchingState) => {
state.status = FetchingStates.IDLE;
state.step = 0;
state.totalSteps = 0;
state.data = '';
})
);
},
setStatus: (status: typeof FetchingStates[keyof typeof FetchingStates]) => {
set(
produce((state: ImageFetchingState) => {
state.status = status;
})
);
},
setStep: (step: number) => {
set(
produce((state: ImageFetchingState) => {
state.step = step;
})
);
},
setTotalSteps: (totalSteps: number) => {
set(
produce((state: ImageFetchingState) => {
state.totalSteps = totalSteps;
})
);
},
}));

View File

@ -370,7 +370,13 @@ def do_mk_img(req: Request):
if req.stream_progress_updates:
n_steps = opt_ddim_steps if req.init_image is None else t_enc
progress = {"step": i, "total_steps": n_steps}
progress = {
"status": "progress",
#"progress": (i + 1) / n_steps,
"progress": {
"step": i, "total_steps": n_steps
}
}
if req.stream_image_progress and i % 5 == 0:
partial_images = []