Stanford-TH committed on
Commit
d862c41
1 Parent(s): 41118c7

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
1
- ---
2
- title: Script Similarity
3
- emoji: 🦀
4
- colorFrom: yellow
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 4.29.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Script Similarity
3
+ emoji: 🦀
4
+ colorFrom: yellow
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 4.29.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
ScriptMatcher.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from ast import literal_eval
4
+ import yake
5
+ import spacy
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ from sentence_transformers import SentenceTransformer
8
+ import os
9
+
10
class ScriptMatcher:
    """
    Rank the series in a precomputed catalogue by similarity to a new synopsis.

    The catalogue (K_Dataset.csv) and two embedding matrices are loaded from
    'models/Similarity_K_Dataset/'. A query is anonymised with spaCy entity
    labels, reduced to YAKE keywords, embedded with a sentence-transformer
    model and compared to both matrices with cosine similarity.
    """

    def __init__(self, data_path=None, model_name='paraphrase-mpnet-base-v2', dataframe=None):
        """
        Initialize the ScriptMatcher object.

        Parameters:
        data_path (str): Optional path to a CSV dataset, stored as `self.dataset`.
        model_name (str): Name of the sentence transformer model. Default is 'paraphrase-mpnet-base-v2'.
        dataframe (pd.DataFrame): Optional dataset; takes precedence over `data_path` when both are given.
        """
        # Local imports: only needed for the spaCy fallback install below.
        import subprocess
        import sys

        # Always define the attribute so it exists even when no dataset is supplied.
        self.dataset = None
        if data_path is not None:
            self.dataset = pd.read_csv(data_path)
        if dataframe is not None:
            self.dataset = dataframe

        self.model = SentenceTransformer(model_name)
        self.kw_extractor = yake.KeywordExtractor("en", n=1, dedupLim=0.9)
        self.k_dataset = pd.read_csv('models/Similarity_K_Dataset/K_Dataset.csv')
        # Keywords equal to one of these labels are dropped by extract_keywords.
        # NOTE(review): "WORK" and "ART" look like a split of spaCy's
        # "WORK_OF_ART" label - confirm which spelling is intended.
        self._ent_type = ["PERSON", "NORP", "FAC", "ORG", "GPE", "LOC", "PRODUCT", "EVENT", "WORK", "ART", "LAW",
                          "LANGUAGE", "DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]
        # NOTE(review): the attribute names and the file names are crossed
        # (the "synopsis" attribute loads plot_embeddings.npy and vice versa).
        # Left unchanged because it decides which matrix gets the 0.75 weight
        # in find_similar_series - confirm the intended pairing.
        self.embeddings_synopsis_list = np.load("models/Similarity_K_Dataset/plot_embeddings.npy")
        self.plot_embedding_list = np.load("models/Similarity_K_Dataset/synopsis_embeddings.npy")

        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            # spaCy raises OSError when the model package is not installed.
            print("Downloading spaCy NLP model...")
            # Install with the running interpreter's pip so the model lands in
            # the same environment that will import it.
            subprocess.run(
                [sys.executable, "-m", "pip", "install",
                 "https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl"],
                check=False,
            )
            # If the install failed, this raises the original OSError again.
            self.nlp = spacy.load("en_core_web_sm")

    def extract_keywords(self, text):
        """
        Extract keywords from a given text using the YAKE keyword extraction algorithm.

        Parameters:
        text (str): Text from which to extract keywords.

        Returns:
        str: A string of extracted keywords joined by spaces; keywords that
        are exactly one of the labels in `self._ent_type` are omitted.
        """
        if not text or not text.strip():
            return ""
        extracted_keywords = self.kw_extractor.extract_keywords(text)
        # Each result is a tuple whose first element is the keyword text.
        return " ".join(item[0] for item in extracted_keywords if item[0] not in self._ent_type)

    def preprocess_text(self, text):
        """
        Process a given text to replace named entities and extract keywords.

        Parameters:
        text (str): The text to process.

        Returns:
        str: Keywords extracted from the text after every token that belongs
        to a named entity (other than "MISC") is replaced by "<LABEL>".
        """
        doc = self.nlp(text)
        # Rebuild the text token by token instead of using str.replace on the
        # whole string: a global replace would also rewrite the same characters
        # inside unrelated words (e.g. the entity "May" inside "Maybe").
        pieces = []
        for token in doc:
            label = token.ent_type_
            surface = f"<{label}>" if label and label != "MISC" else token.text
            pieces.append(surface + token.whitespace_)

        return self.extract_keywords("".join(pieces))

    def find_similar_series(self, new_synopsis, genres_keywords, k=5):
        """
        Find series similar to a new synopsis.

        Parameters:
        new_synopsis (str): The synopsis to compare.
        genres_keywords (list of str): Genre words prepended to the synopsis keywords.
        k (int): The number of similar series to return.

        Returns:
        list of dict: Up to `k` records with keys "Series", "Genre" and
        "Score", ordered from most to least similar. Empty when `k` <= 0.
        """
        # A slice of [-0:] would select every row, so handle k <= 0 explicitly.
        if k <= 0:
            return []

        processed_synopsis = self.preprocess_text(new_synopsis)
        genre_keywords = " ".join(word.strip() for word in genres_keywords if word and word.strip())
        print(genre_keywords)
        # preprocess_text already returns keywords; this is a second extraction
        # pass over that keyword string, kept from the original pipeline.
        synopsis_keywords = self.extract_keywords(processed_synopsis)
        # Join with a space so the last genre word and the first synopsis
        # keyword are not fused into a single token.
        synopsis_sentence = " ".join(part for part in (genre_keywords, synopsis_keywords) if part)

        synopsis_embedding = self.model.encode([synopsis_sentence])

        # Weighted blend of the similarities against both embedding matrices;
        # row 0 is the single query.
        scores = (
            0.75 * cosine_similarity(synopsis_embedding, self.embeddings_synopsis_list)
            + 0.25 * cosine_similarity(synopsis_embedding, self.plot_embedding_list)
        )[0]

        # argsort is ascending: take the last k and reverse for best-first.
        top_k_indices = scores.argsort()[-k:][::-1]
        # Copy before adding a column so the assignment never targets a slice
        # of self.k_dataset.
        closest_series = self.k_dataset.iloc[top_k_indices].copy()
        closest_series["Score"] = scores[top_k_indices]

        return closest_series[["Series", "Genre", "Score"]].to_dict(orient='records')
98
+
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from ScriptMatcher import ScriptMatcher
3
+ # Initialize the ScriptMatcher instance
4
+ scriptmatcher = ScriptMatcher()
5
+
6
def classify_movie_genre(description, genres):
    """
    Given a description (synopsis) and genres, return similar series predictions.

    Parameters:
    description (str): The synopsis text to match.
    genres (str): Comma-separated genre names, e.g. "drama, crime".

    Returns:
    Whatever scriptmatcher.find_similar_series returns for these inputs.
    """
    # Trim whitespace around each genre and drop empty entries, so input such
    # as "drama, crime" or a blank genres box does not yield keywords like
    # " crime" or "".
    genre_keywords = [genre.strip() for genre in (genres or "").split(",") if genre.strip()]

    # NOTE(review): the result is passed to a gr.Dataframe output - confirm the
    # component renders the returned structure as the intended table.
    return scriptmatcher.find_similar_series(description, genre_keywords)
16
+
17
# Assemble the web UI: two text boxes feed classify_movie_genre and the
# result is displayed in a table component.
iface = gr.Interface(
    classify_movie_genre,
    [
        gr.Textbox(label="Synopsis (Description)", lines=5),
        gr.Textbox(label="Genres (Comma-separated)"),
    ],
    gr.Dataframe(label="Similar Series Predictions"),
    title="Genre Prediction",
    description="Provide a movie synopsis and genres to get predictions for similar scripts.",
    live=False,  # compute only when the user submits, not on every keystroke
)

# Start serving the interface
iface.launch(inline=False)
models/Similarity_K_Dataset/K_Dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
models/Similarity_K_Dataset/plot_embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc05423932e01a2907ce69a9832010d116ee64d86e2a19a97bdf28846fd39c92
3
+ size 5222528
models/Similarity_K_Dataset/synopsis_embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b48d966a4993e82122d09441875c93a60f47aca960cee908220d1daf5eba7c92
3
+ size 5222528
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ pandas==2.2.1
2
+ numpy==1.26.4
3
+ yake==0.4.8
4
+ spacy==3.7.4
5
+ scikit-learn==1.2.2
6
+ sentence-transformers==2.6.1