import os
import time
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from TextGen.suno import custom_generate_audio, get_audio_information
from langchain_google_genai import (
ChatGoogleGenerativeAI,
HarmBlockThreshold,
HarmCategory,
)
from TextGen import app
from gradio_client import Client
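
# Overview: this module exposes the TextGen endpoints — NPC text generation with
# Gemini via LangChain, voice lines through an XTTS Gradio Space, and song
# generation through the Suno helpers in TextGen.suno.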
class Message(BaseModel):
    npc: str | None = None
    input: str | None = None


class VoiceMessage(BaseModel):
    npc: str | None = None
    input: str | None = None
    language: str | None = "en"
    genre: str | None = "Male"
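
# Message is the payload for /api/generate; VoiceMessage adds the spoken
# language ("en" by default) and a "genre" field ("Male"/"Female") used to
# pick a fallback reference voice for /generate_wav.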
# Environment configuration: the song-API base URL and the Hugging Face token
# used by the XTTS Gradio client.
song_base_api = os.environ["VERCEL_API"]
my_hf_token = os.environ["HF_TOKEN"]

tts_client = Client("https://jofthomas-xtts.hf.space/", hf_token=my_hf_token)

# Reference voices for the named NPCs; any other NPC falls back to a default
# voice chosen by genre in determine_voice_from_npc below.
main_npcs = {
    "Blacksmith": "./voices/blacksmith.mp3",
    "Herbalist": "./voices/female.wav",
}


class Generate(BaseModel):
    text: str
def generate_text(prompt: str):
    if prompt == "":
        # An empty prompt cannot satisfy the Generate response model, so reject it.
        raise HTTPException(status_code=400, detail="Please provide a prompt.")
    # Build the prompt template from the incoming text. A separate name is used
    # so the original prompt string stays available for the chain run below.
    prompt_template = PromptTemplate(template=prompt, input_variables=["Prompt"])
    # Initialize the LLM (Gemini Pro) with the dangerous-content filter disabled.
    llm = ChatGoogleGenerativeAI(
        model="gemini-pro",
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        },
    )
    llmchain = LLMChain(
        prompt=prompt_template,
        llm=llm,
    )
    llm_response = llmchain.run({"Prompt": prompt})
    return Generate(text=llm_response)
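
# Minimal sketch of the LangChain pieces used above (illustrative values; in
# this endpoint the incoming text serves as both the template and the "Prompt"
# input value):
#   template = PromptTemplate(template="Say hello to {Prompt}", input_variables=["Prompt"])
#   chain = LLMChain(prompt=template, llm=llm)
#   chain.run({"Prompt": "Blacksmith"})  # -> the model's reply as a string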
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/", tags=["Home"])
def api_home():
    return {'detail': 'Welcome to FastAPI TextGen Tutorial!'}


@app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=Generate)
def inference(message: Message):
    return generate_text(prompt=message.input)
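
# Example call (host and port are assumptions for a local uvicorn run):
#   curl -X POST http://localhost:8000/api/generate \
#        -H "Content-Type: application/json" \
#        -d '{"npc": "Blacksmith", "input": "Describe your shop."}'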
# Dummy voice selection for now: named NPCs get their recorded reference voice,
# everyone else falls back to a default voice chosen by genre.
def determine_voice_from_npc(npc, genre):
    if npc in main_npcs:
        return main_npcs[npc]
    if genre == "Male":
        return "./voices/blacksmith.mp3"
    if genre == "Female":
        return "./voices/female.wav"
    return "./voices/narator_out.wav"
@app.post("/generate_wav")
async def generate_wav(message: VoiceMessage):
    try:
        voice = determine_voice_from_npc(message.npc, message.genre)
        # Use the Gradio client to generate the wav file
        result = tts_client.predict(
            message.input,     # str in 'Text Prompt' Textbox component
            message.language,  # str in 'Language' Dropdown component
            voice,             # str (filepath or URL) in 'Reference Audio' Audio component
            voice,             # str (filepath or URL) in 'Use Microphone for Reference' Audio component
            False,             # bool in 'Use Microphone' Checkbox component
            False,             # bool in 'Cleanup Reference Voice' Checkbox component
            False,             # bool in 'Do not use language auto-detect' Checkbox component
            True,              # bool in 'Agree' Checkbox component
            fn_index=1,
        )
        # Get the path of the generated wav file
        wav_file_path = result[1]
        # Return the generated wav file as a response
        return FileResponse(wav_file_path, media_type="audio/wav", filename="output.wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
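
# Example call (host/port assumed; the response body is the generated wav file):
#   curl -X POST http://localhost:8000/generate_wav \
#        -H "Content-Type: application/json" \
#        -d '{"npc": "Herbalist", "input": "Welcome, traveler.", "language": "en", "genre": "Female"}' \
#        -o output.wav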
@app.get("/generate_song")
async def generate_song(text: str):
    try:
        data = custom_generate_audio({
            "prompt": f"{text}",
            "make_instrumental": False,
            "wait_audio": False,
        })
        # The API queues two clips; poll until they start streaming.
        ids = f"{data[0]['id']},{data[1]['id']}"
        print(f"ids: {ids}")
        for _ in range(60):
            data = get_audio_information(ids)
            if data[0]["status"] == "streaming":
                print(f"{data[0]['id']} ==> {data[0]['audio_url']}")
                print(f"{data[1]['id']} ==> {data[1]['audio_url']}")
                break
            # sleep 5s between polls (up to 5 minutes in total)
            time.sleep(5)
        # Return the latest clip info (ids, status, audio URLs) to the caller.
        return data
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
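
# Example call (host/port assumed; the Suno proxy behind VERCEL_API must be reachable):
#   curl "http://localhost:8000/generate_song?text=a%20cheerful%20tavern%20tune"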