# Metallica song finder: RAG demo combining LanceDB (vector search) and
# Kùzu (knowledge graph), served as a Solara web app.
# from dotenv import find_dotenv, load_dotenv
# _ = load_dotenv(find_dotenv())
import string

import polars as pl
import solara

# Metallica lyrics dataset (public Google Drive CSV with Song/Lyrics/Album/Artist).
df = pl.read_csv(
    "https://drive.google.com/uc?export=download&id=1uD3h7xYxr9EoZ0Ggoh99JtQXa3AxtxyU"
)

# Normalize casing and backfill missing lyrics in a single pass (the original
# did three separate with_columns passes over the frame).
# NOTE: the literal string "None" for missing lyrics matches the original
# behavior — the text column below embeds it verbatim.
df = df.with_columns(
    pl.Series("Album", [string.capwords(album) for album in df["Album"]]),
    pl.Series("Song", [string.capwords(song) for song in df["Song"]]),
    pl.col("Lyrics").fill_null("None"),
).with_columns(
    # Markdown-style document per song: "# Album: Song\n\nLyrics".
    # This is the column that gets embedded for vector search.
    text=pl.lit("# ")
    + pl.col("Album")
    + pl.lit(": ")
    + pl.col("Song")
    + pl.lit("\n\n")
    + pl.col("Lyrics")
)
import shutil

import lancedb
from lancedb.embeddings import get_registry

# Recreate the on-disk vector store from scratch on every run.
shutil.rmtree("test_lancedb", ignore_errors=True)
db = lancedb.connect("test_lancedb")

# Small, CPU-friendly sentence-transformer used to embed the `text` column.
embeddings = get_registry().get("sentence-transformers").create(
    name="TaylorAI/gte-tiny", device="cpu"
)
from lancedb.pydantic import LanceModel, Vector
class Songs(LanceModel):
    """Row schema for the LanceDB table.

    `text` is declared as the embedding source and `vector` as its target,
    so LanceDB computes embeddings automatically when rows are added.
    """
    Song: str
    Lyrics: str
    Album: str
    Artist: str
    # Source/vector pair wired to the gte-tiny embedder configured above.
    text: str = embeddings.SourceField()
    vector: Vector(embeddings.ndims()) = embeddings.VectorField()
# Create the table and ingest the dataframe (embeddings computed on add).
table = db.create_table("Songs", schema=Songs)
table.add(data=df)
import os
from typing import Optional
from langchain_community.chat_models import ChatOpenAI
class ChatOpenRouter(ChatOpenAI):
    """ChatOpenAI pointed at the OpenRouter API endpoint.

    Reads the API key from the OPENROUTER_API_KEY environment variable when
    no key is passed explicitly.
    """

    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(
        self,
        model_name: str,
        openai_api_key: Optional[str] = None,
        openai_api_base: str = "https://openrouter.ai/api/v1",
        **kwargs,
    ):
        # Bug fix: the original unconditionally overwrote the parameter with
        # the env var, silently discarding an explicitly supplied key. The
        # environment variable is now only a fallback.
        openai_api_key = openai_api_key or os.getenv("OPENROUTER_API_KEY")
        super().__init__(
            openai_api_base=openai_api_base,
            openai_api_key=openai_api_key,
            model_name=model_name,
            **kwargs,
        )
llm_openrouter = ChatOpenRouter(model_name="meta-llama/llama-3.1-405b-instruct", temperature=0.1)
def get_relevant_texts(query, table=table, k=5):
    """Vector-search `table` for `query` and return the top-`k` song texts
    joined into one context string, each followed by a "---" separator.

    Bug fix: the original indexed results with a hard-coded range(5), which
    raised IndexError whenever the search returned fewer than 5 rows; we now
    iterate over the rows actually returned. `k` keeps the old default.
    """
    results = table.search(query).limit(k).to_polars()
    return " ".join(text + "\n\n---\n\n" for text in results["text"])
def generate_prompt(query, table=table):
    """Assemble the RAG prompt: retrieved lyrics context plus the question."""
    context = get_relevant_texts(query, table)
    return f"Answer the question based only on the following context:\n\n{context}\n\nQuestion: {query}"
def generate_response(query, table=table):
    """Answer `query` with the LLM, grounded on lyrics retrieved from `table`."""
    answer = llm_openrouter.invoke(input=generate_prompt(query, table))
    return answer.content
import kuzu
# Rebuild the graph database from scratch on every run.
shutil.rmtree("test_kuzudb", ignore_errors=True)
# NOTE(review): `db` rebinds the name that previously held the LanceDB
# connection; the vector store remains reachable through `table`.
db = kuzu.Database("test_kuzudb")
conn = kuzu.Connection(db)
# Create schema
# Nodes: ARTIST and ALBUM keyed by name; SONG keyed by an auto SERIAL id so
# duplicate song names across albums are allowed.
conn.execute("CREATE NODE TABLE ARTIST(name STRING, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE ALBUM(name STRING, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE SONG(ID SERIAL, name STRING, lyrics STRING, PRIMARY KEY(ID))")
# Relationships: SONG -IN_ALBUM-> ALBUM and ALBUM -FROM_ARTIST-> ARTIST.
conn.execute("CREATE REL TABLE IN_ALBUM(FROM SONG TO ALBUM)")
conn.execute("CREATE REL TABLE FROM_ARTIST(FROM ALBUM TO ARTIST)");
# Insert nodes.
# Fix: use Kùzu query parameters ($name / $lyrics) instead of f-string
# interpolation — the original broke on values containing the quote style of
# the surrounding literal and was injection-prone.
for artist in df["Artist"].unique():
    conn.execute("CREATE (artist:ARTIST {name: $name})", {"name": artist})
for album in df["Album"].unique():
    conn.execute("CREATE (album:ALBUM {name: $name})", {"name": album})
for song, lyrics in df.select(["Song", "text"]).unique().rows():
    # Keep the original double-quote -> single-quote normalization: the
    # edge-building step matches songs on the stored lyrics string.
    conn.execute(
        "CREATE (song:SONG {name: $name, lyrics: $lyrics})",
        {"name": song, "lyrics": lyrics.replace('"', "'")},
    )
# Insert edges.
# Fix: parameterized MATCH/CREATE instead of f-string interpolation, for the
# same quoting/injection reasons as the node inserts.
for song, album, lyrics in df.select(["Song", "Album", "text"]).rows():
    # Match on (name, lyrics) because song names alone are not unique; the
    # stored lyrics had '"' normalized to "'", so apply the same replacement.
    conn.execute(
        """
        MATCH (song:SONG), (album:ALBUM)
        WHERE song.name = $song AND song.lyrics = $lyrics AND album.name = $album
        CREATE (song)-[:IN_ALBUM]->(album)
        """,
        {"song": song, "lyrics": lyrics.replace('"', "'"), "album": album},
    )
for album, artist in df.select(["Album", "Artist"]).unique().rows():
    conn.execute(
        """
        MATCH (album:ALBUM), (artist:ARTIST)
        WHERE album.name = $album AND artist.name = $artist
        CREATE (album)-[:FROM_ARTIST]->(artist)
        """,
        {"album": album, "artist": artist},
    )
# Sanity-check query: list the songs linked to "The Black Album".
response = conn.execute(
    """
    MATCH (a:ALBUM {name: 'The Black Album'})<-[:IN_ALBUM]-(s:SONG) RETURN s.name
    """
)
# Materialize the result as a polars DataFrame.
df_response = response.get_as_pl()
from langchain_community.graphs import KuzuGraph
# LangChain wrapper used only to surface the graph schema in the LLM prompt.
graph = KuzuGraph(db)
def generate_kuzu_prompt(user_query):
    """Build the text-to-Cypher prompt: dialect rules, the live graph schema,
    and one worked example, followed by the user's question."""
    return f"""Task: Generate Kùzu Cypher statement to query a graph database.
Instructions:
Generate the Kùzu dialect of Cypher with the following rules in mind:
1. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`.
2. Do not include triple backticks ``` in your response. Return only Cypher.
3. Do not return any notes or comments in your response.
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
{graph.get_schema}
Example:
The question is:
"Which songs does the load album have?"
MATCH (a:ALBUM {{name: 'Load'}})<-[:IN_ALBUM]-(s:SONG) RETURN s.name
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.
The question is:
{user_query}"""
# Template for the answer-synthesis prompt; filled in by generate_final_prompt.
_ANSWER_PROMPT = """You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you must use to construct an answer.
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
Here is an example:
Question: Which managers own Neo4j stocks?
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
Follow this example when generating answers.
If the provided information is empty, say that you don't know the answer.
Query:
{cypher_query}
Information:
[{col_name}: {values}]
Question: {query}
Helpful Answer:
"""


def generate_final_prompt(query, cypher_query, col_name, _values):
    """Render the prompt that turns a Cypher result column into a natural
    language answer to `query`."""
    return _ANSWER_PROMPT.format(
        query=query, cypher_query=cypher_query, col_name=col_name, values=_values
    )
def generate_kg_response(query):
    """Answer `query` via the knowledge graph.

    Flow: the LLM translates the question into Cypher, the Cypher runs on
    Kùzu, and a second LLM call phrases the first result column as prose.
    Returns (answer_text, generated_cypher) so the UI can show both.
    """
    # Step 1: natural language -> Cypher.
    kuzu_prompt = generate_kuzu_prompt(query)
    cypher_query = llm_openrouter.invoke(input=kuzu_prompt).content
    # Step 2: execute the generated statement.
    result = conn.execute(
        f"""
        {cypher_query}
        """
    )
    result_df = result.get_as_pl()
    # Step 3: only the first returned column is surfaced to the answer LLM.
    col_name = result_df.columns[0]
    values = result_df[col_name].to_list()
    final_prompt = generate_final_prompt(query, cypher_query, col_name, values)
    answer = llm_openrouter.invoke(input=final_prompt).content
    return answer, cypher_query
# Reactive state backing the input box; changing it re-renders Page.
query = solara.reactive("How many songs does the black album have?")


# Bug fix: Page was a plain function; without the @solara.component decorator
# it is not a reactive component, so edits to `query` never re-render the UI.
@solara.component
def Page():
    """Solara UI: a question box plus the KG answer and the generated Cypher."""
    with solara.Column(margin=10):
        solara.Markdown("# Metallica Song Finder graph-only")
        solara.InputText("Enter some query:", query, continuous_update=False)
        if query.value != "":
            response, cypher_query = generate_kg_response(query.value)
            solara.Markdown("## Answer:")
            solara.Markdown(response)
            solara.Markdown("## Cypher query:")
            solara.Markdown(cypher_query)