yangswei commited on
Commit
8a77843
1 Parent(s): 275583c

Update song-insight-app.py

Browse files
Files changed (1) hide show
  1. song-insight-app.py +32 -29
song-insight-app.py CHANGED
@@ -6,6 +6,7 @@ from langchain.chains import LLMChain
6
  from langchain_community.retrievers import WikipediaRetriever
7
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
8
  from langchain_google_genai import ChatGoogleGenerativeAI
 
9
  import os
10
 
11
 
@@ -18,12 +19,15 @@ def song_insight(song, artist):
18
  docs = retriever.get_relevant_documents(query=query_input)
19
 
20
  # LLM model
21
- # llm = ChatOpenAI(openai_api_key=os.environ['OPENAI_API_KEY'], model_name="gpt-3.5-turbo", temperature=0)
22
- llm = ChatGoogleGenerativeAI(model="gemini-pro", google_api_key=os.environ['GOOGLE_API_KEY'])
23
-
24
- # Emotion Classifier Model
25
- tokenizer = AutoTokenizer.from_pretrained("yangswei/emotion_text_classification")
26
- emotion_model = AutoModelForSequenceClassification.from_pretrained("yangswei/emotion_text_classification")
 
 
 
27
 
28
  # get the song meaning
29
  template_song_meaning = """
@@ -31,43 +35,42 @@ def song_insight(song, artist):
31
 
32
  {content}
33
 
34
- based on the the content above what does the song {song} by {artist} tell us about? give me a long explanations
 
35
 
36
  """
37
- prompt_template_song_meaning = PromptTemplate(input_variables=["artist", "song", "content"],
38
- template=template_song_meaning)
39
  chain_song_meaning = LLMChain(llm=llm, prompt=prompt_template_song_meaning)
40
- results_song_meaning = chain_song_meaning.run(artist=artist.title(), song=song.title(),
41
- content=docs[0].page_content)
42
 
43
- # get the song theme
44
- template_song_theme = """
45
- {artist} has released a song called {song}.
46
 
47
- {content}
 
 
 
 
 
 
48
 
49
- based on the the content above what themes does the lyrics have?
50
 
51
  """
52
- prompt_template_song_theme = PromptTemplate(input_variables=["artist", "song", "content"],
53
- template=template_song_theme)
54
- chain_song_theme = LLMChain(llm=llm, prompt=prompt_template_song_theme)
55
- text_song_theme = chain_song_theme.run(artist=artist.title(), song=song.title(), content=docs[0].page_content)
56
- inputs_song_theme = tokenizer(text_song_theme, return_tensors="pt")
57
- output_song_theme_proba = emotion_model(**inputs_song_theme).logits.softmax(1)
58
- labels = emotion_model.config.id2label
59
- confidences = {labels[i]: output_song_theme_proba[0][i].item() for i in range(len(labels))}
60
 
61
- return results_song_meaning, confidences
 
 
62
 
 
63
 
64
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
65
  song = gr.Textbox(label="Song")
66
  artist = gr.Textbox(label="Artist")
67
  output_song_meaning = gr.Textbox(label="Meaning")
68
- output_song_theme = gr.Label(num_top_classes=6, label="Theme")
69
- gr.Interface(fn=song_insight, inputs=[song, artist], outputs=[output_song_meaning, output_song_theme])
70
- example = gr.Examples([['Life Goes On', 'BTS'], ['Here Comes The Sun', 'The Beatles'],
71
- ['Bedtime Stories', 'Jay Chou'], ['Loser', 'BIGBANG']], [song, artist])
72
 
73
  demo.launch()
 
6
  from langchain_community.retrievers import WikipediaRetriever
7
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
10
  import os
11
 
12
 
 
19
  docs = retriever.get_relevant_documents(query=query_input)
20
 
21
  # LLM model
22
+ # llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model_name="gpt-3.5-turbo", temperature=0)
23
+ safety_setting = {
24
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
25
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
26
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
27
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
28
+ }
29
+ llm = ChatGoogleGenerativeAI(model="gemini-pro", google_api_key=GOOGLE_API_KEY, temperature=0,
30
+ safety_settings=safety_setting, top_p=0)
31
 
32
  # get the song meaning
33
  template_song_meaning = """
 
35
 
36
  {content}
37
 
38
+ based on the the content above what does the song {song} by {artist} tell us about? give me a long explanations and
39
+ do not bold any text.
40
 
41
  """
42
+ prompt_template_song_meaning = PromptTemplate(input_variables=["artist", "song", "content"], template=template_song_meaning)
 
43
  chain_song_meaning = LLMChain(llm=llm, prompt=prompt_template_song_meaning)
44
+ results_song_meaning = chain_song_meaning.run(artist=artist.title(), song=song.title(), content=docs[0].page_content)
 
45
 
46
+ # get song recom
47
+ template_song_recom = """
48
+ here are the meaning of {song} by {artist}:
49
 
50
+ {song_meaning}
51
+
52
+ can you give me a 3 songs recommendation similar to the meaning of the song above?
53
+ use this format for the output and do not bold any text:
54
+ 1. recommended song 1, with a brief explanation.
55
+ 2. recommended song 2, with a brief explanation.
56
+ 3. recommended song 3, with a brief explanation.
57
 
 
58
 
59
  """
 
 
 
 
 
 
 
 
60
 
61
+ prompt_template_song_recom = PromptTemplate(input_variables=["artist", "song", "song_meaning"], template=template_song_recom)
62
+ chain_song_recom = LLMChain(llm=llm, prompt=prompt_template_song_recom)
63
+ results_song_recom = chain_song_recom.run(artist=artist, song=song, song_meaning=results_song_meaning)
64
 
65
+ return results_song_meaning, results_song_recom
66
 
67
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
68
  song = gr.Textbox(label="Song")
69
  artist = gr.Textbox(label="Artist")
70
  output_song_meaning = gr.Textbox(label="Meaning")
71
+ output_song_recom = gr.Textbox(label="Song Recommendation")
72
+ gr.Interface(fn=song_insight, inputs=[song, artist], outputs=[output_song_meaning, output_song_recom])
73
+ example = gr.Examples([["They Don't Care About Us", 'Michael Jackson'], ["Bad Romance", 'Lady Gaga'],
74
+ ["Let It Be", "The Beatles"], ["Life Goes On", 'BTS'], ["Blank Space", "Taylor Swift"]], [song, artist])
75
 
76
  demo.launch()