|
|
|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel |
|
from pathlib import Path |
|
from dotenv import load_dotenv |
|
from langchain.chains import LLMChain |
|
from langchain.prompts import PromptTemplate |
|
from langchain.chat_models import ChatOpenAI |
|
from langchain.embeddings.openai import OpenAIEmbeddings |
|
from langchain.schema import Document |
|
from data import load_db |
|
from names import DATASET_ID, MODEL_ID |
|
from storage import RedisStorage |
|
from utils import weighted_random_sample |
|
import os |
|
import numpy as np |
|
from typing import List, Dict |
|
|
|
|
|
# Load settings from a local .env file (if present) before any code below
# reads os.environ.
load_dotenv()

# The ASGI application object; route handlers are registered on it below.
app = FastAPI(
    title="Play My Emotions API",
    description="API for emotion-based song recommendation",
    version="1.0",
)
|
|
|
|
|
class EmotionInput(BaseModel):
    """Request body accepted by the ``/recommend`` endpoint."""

    # Free-form text supplied by the caller; it is forwarded to the LLM chain
    # as the ``user_input`` prompt variable.
    emotion_text: str
|
|
|
class SongRecommendation(BaseModel):
    """Response body returned by the ``/recommend`` endpoint."""

    # Text produced by the LLM chain for the caller's input.
    emotions: str

    # One mapping per recommended song; the endpoint fills in the keys
    # "name" and "embed_url".
    recommended_songs: List[Dict[str, str]]
|
|
|
|
|
# Boolean flag parsed from the USE_STORAGE environment variable. It is true
# for "true", "t" or "1" (case-insensitive) and defaults to enabled when the
# variable is unset.
# NOTE(review): nothing in the visible code reads this flag; presumably it is
# meant to toggle the Redis storage. Confirm against the callers.
USE_STORAGE = os.getenv("USE_STORAGE", "True").lower() in {"true", "t", "1"}
|
|
|
|
|
def init():
    """Build the long-lived resources the API needs.

    Reads ``ACTIVELOOP_ORG_ID``, ``ACTIVELOOP_TOKEN``, ``UPSTASH_URL`` and
    ``UPSTASH_PASSWORD`` from the environment and the prompt template from
    ``prompts/bot.prompt`` (resolved against the current working directory).

    Returns:
        A ``(db, storage, chain)`` tuple: the dataset opened read-only through
        ``load_db`` with OpenAI embeddings, a ``RedisStorage`` client, and an
        ``LLMChain`` that fills the prompt's ``user_input`` variable.

    Raises:
        KeyError: If one of the required environment variables is missing.
        OSError: If the prompt file cannot be read.
    """
    embeddings = OpenAIEmbeddings(model=MODEL_ID)
    dataset_path = f"hub://{os.environ['ACTIVELOOP_ORG_ID']}/{DATASET_ID}"

    db = load_db(
        dataset_path,
        embedding_function=embeddings,
        token=os.environ["ACTIVELOOP_TOKEN"],
        read_only=True,
    )

    storage = RedisStorage(
        host=os.environ["UPSTASH_URL"],
        password=os.environ["UPSTASH_PASSWORD"],
    )

    prompt = PromptTemplate(
        input_variables=["user_input"],
        # Decode explicitly as UTF-8: the platform default encoding would
        # garble or reject non-ASCII prompt text on some systems.
        template=Path("prompts/bot.prompt").read_text(encoding="utf-8"),
    )

    llm = ChatOpenAI(temperature=0.3)
    chain = LLMChain(llm=llm, prompt=prompt)

    return db, storage, chain
|
|
|
# Build the shared resources once, at import time; the request handlers
# below use them as module-level globals.
(db, storage, chain) = init()
|
|
|
|
|
@app.post("/recommend", response_model=SongRecommendation)
async def recommend_song(emotion: EmotionInput):
    """Recommend songs for the emotion text in the request body.

    Args:
        emotion: Request payload; ``emotion_text`` carries the user's text.

    Returns:
        A dict with ``emotions`` (the text produced by the LLM chain) and
        ``recommended_songs`` (one ``{"name", "embed_url"}`` dict per song).

    Raises:
        HTTPException: 400 when the text is empty or only whitespace.
    """
    # Imported here so the handler does not rely on a file-level import.
    from fastapi.concurrency import run_in_threadpool

    # Strip first so that whitespace-only input is rejected as well.
    user_input = emotion.emotion_text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Emotion input is required")

    # get_song makes synchronous LLM and database calls; running it in the
    # threadpool keeps the event loop free to serve other requests.
    docs, emotions = await run_in_threadpool(get_song, user_input, k=20)

    recommended_songs = [
        {"name": doc.metadata["name"], "embed_url": doc.metadata["embed_url"]}
        for doc in docs
    ]

    return {"emotions": emotions, "recommended_songs": recommended_songs}
|
|
|
|
|
def get_song(user_input: str, k: int = 20, threshold: float = 0.8, n: int = 2):
    """Pick songs whose stored entries best match the user's text.

    The LLM chain turns ``user_input`` into a text that is used as the query
    for a cosine similarity search over ``db``. Matches scoring at or below
    ``threshold`` are dropped; the remaining scores are normalized to sum to
    1 and passed as weights to ``weighted_random_sample``.

    Args:
        user_input: Text describing how the user feels.
        k: Number of nearest matches requested from the database.
        threshold: Only matches with a score strictly above this are kept.
        n: Maximum number of documents to sample.

    Returns:
        A ``(chosen_docs, emotions)`` tuple. ``chosen_docs`` is a list of the
        sampled documents (empty when no match passes the threshold) and
        ``emotions`` is the chain's output.
    """
    emotions = chain.run(user_input=user_input)
    matches = db.similarity_search_with_score(emotions, distance_metric="cos", k=k)

    weighted = normalize_scores_by_sum(filter_scores(matches, threshold))
    if not weighted:
        # zip(*[]) yields nothing, so the unpacking below would raise
        # ValueError; report "no recommendation" instead.
        return [], emotions

    docs, scores = zip(*weighted)

    # Never ask for more samples than there are candidates.
    # NOTE(review): presumably weighted_random_sample draws without
    # replacement; confirm against its implementation.
    sample_size = min(n, len(docs))
    chosen_docs = weighted_random_sample(
        np.array(docs), np.array(scores), n=sample_size
    ).tolist()

    return chosen_docs, emotions
|
|
|
|
|
def filter_scores(matches, th: float = 0.8):
    """Keep only the ``(doc, score)`` pairs whose score is strictly above ``th``.

    Args:
        matches: Iterable of ``(doc, score)`` pairs.
        th: Exclusive lower bound for the score.

    Returns:
        A new list of ``(doc, score)`` tuples, in the original order.
    """
    kept = []
    for doc, score in matches:
        if score > th:
            kept.append((doc, score))
    return kept
|
|
|
|
|
def normalize_scores_by_sum(matches):
    """Rescale the scores of ``(doc, score)`` pairs so that they sum to 1.

    Args:
        matches: Iterable of ``(doc, score)`` pairs; any iterable is accepted,
            including one-shot iterators.

    Returns:
        A list of ``(doc, weight)`` tuples in the original order, where
        ``weight`` is ``score / sum(scores)``. An empty input gives an empty
        list. If the scores sum to zero, every pair gets the same weight
        ``1 / len(pairs)`` instead of raising ``ZeroDivisionError``.
    """
    # Materialize once: the pairs are traversed twice below, which would
    # silently produce an empty result for a generator.
    pairs = list(matches)
    if not pairs:
        return []

    total = sum(score for _, score in pairs)
    if total == 0:
        # No meaningful proportions exist; fall back to uniform weights.
        uniform = 1 / len(pairs)
        return [(doc, uniform) for doc, _ in pairs]

    return [(doc, score / total) for doc, score in pairs]
|
|
|
|
|
if __name__ == "__main__":
    # Run the API with uvicorn when this file is executed as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|
|