evBackend / TextGen /router.py
Jofthomas's picture
Jofthomas HF staff
Update TextGen/router.py
03d50be verified
raw
history blame
4.96 kB
import os
import time
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from TextGen.suno import custom_generate_audio, get_audio_information
from langchain_google_genai import (
ChatGoogleGenerativeAI,
HarmBlockThreshold,
HarmCategory,
)
from TextGen import app
from gradio_client import Client
from typing import List
class Message(BaseModel):
    """Request body for ``/api/generate``.

    ``npc`` names the character to role-play. ``messages`` is the chat
    history as read by ``generate_text``: even positions are user turns and
    odd positions are the character's earlier replies.
    """

    npc: str | None = None
    messages: list[str] | None = None
class VoiceMessage(BaseModel):
    """Request body for ``/generate_wav``.

    ``genre`` picks the fallback voice ("Male"/"Female") for NPCs that have
    no dedicated reference audio.
    """

    npc: str | None = None
    input: str | None = None  # text to synthesize
    language: str | None = "en"  # forwarded as the TTS 'Language' input
    genre: str | None = "Male"
# Deployment configuration. Both variables are mandatory: a missing one raises
# KeyError as soon as this module is imported.
song_base_api, my_hf_token = (
    os.environ[_env_name] for _env_name in ("VERCEL_API", "HF_TOKEN")
)
# Shared Gradio client for the XTTS Space, authenticated with the HF token.
tts_client = Client("https://jofthomas-xtts.hf.space/", hf_token=my_hf_token)
# Reference-audio file used as the cloned TTS voice for each named NPC.
main_npcs = dict(
    Blacksmith="./voices/Blacksmith.mp3",
    Herbalist="./voices/female.mp3",
    Bard="./voices/Bard_voice.mp3",
)
# Persona prompt sent to the LLM for each named NPC.
main_npc_system_prompts = dict(
    Blacksmith="You are a blacksmith in a video game",
    Herbalist="You are an herbalist in a video game",
    Bard="You are a bard in a video game",
)
class Generate(BaseModel):
    """Response body of ``/api/generate``: the character's reply text."""

    text: str
def generate_text(messages: List[str] | None, npc: str | None):
    """Ask Gemini for the next line of an NPC conversation.

    Args:
        messages: Chat history. Even indices are sent as user turns and odd
            indices as the NPC's previous (assistant) replies. ``None`` is
            treated as an empty history.
        npc: Character name. Names present in ``main_npc_system_prompts`` get
            their own persona; any other value gets a generic role-play prompt.

    Returns:
        A ``Generate`` holding the model's reply, capped at 100 output tokens.
    """
    print(npc)
    # Look the persona up in the prompt table itself. The old code tested
    # membership in ``main_npcs`` (the voice table) and then indexed the
    # prompt table, which raised KeyError whenever the two tables diverged.
    system_prompt = main_npc_system_prompts.get(
        npc, "you're a character in a video game. Play along."
    )
    print(system_prompt)
    # The persona is sent as a leading "user" turn instead of a system
    # message (presumably because gemini-pro lacked a system role).
    # NOTE(review): this yields two consecutive "user" turns before the first
    # reply; confirm the Gemini backend accepts that.
    new_messages = [{"role": "user", "content": system_prompt}]
    new_messages.extend(
        {"role": "assistant" if index % 2 else "user", "content": message}
        # ``Message.messages`` defaults to None; treat that as no history.
        for index, message in enumerate(messages or [])
    )
    print(new_messages)
    # Initialize the LLM
    llm = ChatGoogleGenerativeAI(
        model="gemini-pro",
        max_output_tokens=100,
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        },
    )
    llm_response = llm.invoke(new_messages)
    print(llm_response)
    return Generate(text=llm_response.content)
# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): the FastAPI/Starlette CORS docs say "*" should not be combined
# with allow_credentials=True; revisit if this API ever relies on cookies/auth.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/", tags=["Home"])
def api_home():
    """Landing endpoint: return a static welcome payload."""
    greeting = 'Welcome to FastAPI TextGen Tutorial!'
    return dict(detail=greeting)
@app.post(
    "/api/generate",
    summary="Generate text from prompt",
    tags=["Generate"],
    response_model=Generate,
)
def inference(message: Message):
    """Return the NPC's next reply for the conversation carried by ``message``."""
    history, character = message.messages, message.npc
    return generate_text(messages=history, npc=character)
# Placeholder voice selection; intended to be replaced by real logic later.
def determine_vocie_from_npc(npc, genre, voices=None):
    """Pick the reference-audio file whose voice the TTS should clone.

    Named NPCs get their dedicated sample; everyone else falls back to a
    default voice chosen by ``genre``.

    Args:
        npc: Character name.
        genre: "Male" or "Female"; any other value (including None) selects
            the narrator voice.
        voices: Optional NPC-name -> audio-path mapping; defaults to
            ``main_npcs``.

    Returns:
        Path of the reference audio file.
    """
    if voices is None:
        voices = main_npcs
    if npc in voices:
        return voices[npc]
    # The "Male" branch used to lack ``return``, so male NPCs silently fell
    # through to the narrator voice.
    # NOTE(review): confirm ./voices/default_male.mp3 is actually deployed.
    fallback_voices = {
        "Male": "./voices/default_male.mp3",
        "Female": "./voices/default_female.mp3",
    }
    return fallback_voices.get(genre, "./voices/narator_out.wav")
@app.post("/generate_wav")
async def generate_wav(message: VoiceMessage):
    """Synthesize ``message.input`` in the NPC's voice and return it as a WAV.

    Raises:
        HTTPException: 400 when there is no text to synthesize; 500 (with the
            error text as detail) when voice selection or the TTS call fails.
    """
    import asyncio

    # Validate before the try block so this 400 is not rewrapped as a 500.
    if not message.input:
        raise HTTPException(status_code=400, detail="input must be a non-empty string")
    try:
        voice = determine_vocie_from_npc(message.npc, message.genre)
        # ``predict`` is a synchronous network call; run it in a worker thread
        # so the event loop keeps serving other requests meanwhile.
        result = await asyncio.to_thread(
            tts_client.predict,
            message.input,  # str in 'Text Prompt' Textbox component
            message.language,  # str in 'Language' Dropdown component
            voice,  # filepath/URL in 'Reference Audio' Audio component
            voice,  # filepath/URL in 'Use Microphone for Reference' Audio component
            False,  # bool in 'Use Microphone' Checkbox component
            False,  # bool in 'Cleanup Reference Voice' Checkbox component
            False,  # bool in 'Do not use language auto-detect' Checkbox component
            True,  # bool in 'Agree' Checkbox component
            fn_index=1,
        )
        # The second output of the Space is the generated wav file path.
        wav_file_path = result[1]
        return FileResponse(wav_file_path, media_type="audio/wav", filename="output.wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/generate_song")
async def generate_song(text: str):
    """Start a song generation for ``text`` and wait until it starts streaming.

    Polls ``get_audio_information`` every 5 seconds, at most 60 times
    (about 5 minutes), until the first clip reports status "streaming".

    Returns:
        The latest audio-information list. If the polling budget runs out,
        its status will not be "streaming" yet.

    Raises:
        HTTPException: 500 if starting or polling the generation fails.
    """
    import asyncio

    try:
        # The helpers are called synchronously (their results are indexed
        # directly), so run them in worker threads instead of blocking the loop.
        data = await asyncio.to_thread(custom_generate_audio, {
            "prompt": text,
            "make_instrumental": False,
            "wait_audio": False,
        })
        # Track every returned clip rather than assuming exactly two.
        ids = ",".join(str(clip["id"]) for clip in data)
        print(f"ids: {ids}")
        for _ in range(60):
            data = await asyncio.to_thread(get_audio_information, ids)
            if data[0]["status"] == 'streaming':
                for clip in data:
                    print(f"{clip['id']} ==> {clip['audio_url']}")
                break
            # Non-blocking wait: time.sleep here stalled every other request.
            await asyncio.sleep(5)
    except Exception as e:
        # Previously a bare ``except`` printed "Error" and returned null,
        # hiding failures (and swallowing cancellation).
        print(f"Error: {e!r}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return data