mirror of
https://github.com/thorstenMueller/Thorsten-Voice.git
synced 2024-11-21 15:33:10 +01:00
Create privateGPT_Voice.py (initial version)
This commit is contained in:
parent
28831d8642
commit
df29c29e09
58
Youtube/privateGPT_Voice.py
Normal file
58
Youtube/privateGPT_Voice.py
Normal file
@ -0,0 +1,58 @@
|
||||
# Script is originally taken from https://github.com/imartinez/privateGPT/blob/main/privateGPT.py
|
||||
# and i've added some STT and TTS stuff.
|
||||
# See tutorial on my Youtube channel here:
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain.chains import RetrievalQA
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.llms import GPT4All, LlamaCpp
|
||||
import os
|
||||
from constants import CHROMA_SETTINGS
|
||||
|
||||
import whisper
|
||||
from TTS.api import TTS
|
||||
|
||||
##############################
|
||||
# Whisper for local STT part #
|
||||
##############################
|
||||
|
||||
model = whisper.load_model("base")
|
||||
result = model.transcribe("input.wav")
|
||||
query = result["text"]
|
||||
|
||||
##################################
|
||||
# PrivateGPT for query documents #
|
||||
##################################
|
||||
load_dotenv()
|
||||
|
||||
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
|
||||
persist_directory = os.environ.get('PERSIST_DIRECTORY')
|
||||
|
||||
model_type = os.environ.get('MODEL_TYPE')
|
||||
print("Model type: " + model_type)
|
||||
model_path = os.environ.get('MODEL_PATH')
|
||||
print("Model path: " + model_path)
|
||||
model_n_ctx = os.environ.get('MODEL_N_CTX')
|
||||
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))
|
||||
|
||||
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
|
||||
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
|
||||
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
|
||||
llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=None, verbose=False)
|
||||
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=False)
|
||||
|
||||
print("")
|
||||
print(">> Asking my documents: " + query)
|
||||
print("")
|
||||
|
||||
res = qa(query)
|
||||
print("")
|
||||
|
||||
############################
|
||||
# Coqui for local TTS part #
|
||||
############################
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
tts = TTS(model_name="tts_models/en/ljspeech/vits--neon")
|
||||
tts.tts_to_file(text=res['result'])
|
Loading…
Reference in New Issue
Block a user