Shad0ws doevent committed on
Commit
e1b51d8
0 Parent(s):

Duplicate from doevent/msk

Browse files

Co-authored-by: Max Skobeev <doevent@users.noreply.huggingface.co>

Files changed (6) hide show
  1. .gitattributes +33 -0
  2. README.md +14 -0
  3. app.py +145 -0
  4. constants.py +7 -0
  5. requirements.txt +3 -0
  6. utils.py +50 -0
.gitattributes ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.npy filter=lfs diff=lfs merge=lfs -text
14
+ *.npz filter=lfs diff=lfs merge=lfs -text
15
+ *.onnx filter=lfs diff=lfs merge=lfs -text
16
+ *.ot filter=lfs diff=lfs merge=lfs -text
17
+ *.parquet filter=lfs diff=lfs merge=lfs -text
18
+ *.pb filter=lfs diff=lfs merge=lfs -text
19
+ *.pickle filter=lfs diff=lfs merge=lfs -text
20
+ *.pkl filter=lfs diff=lfs merge=lfs -text
21
+ *.pt filter=lfs diff=lfs merge=lfs -text
22
+ *.pth filter=lfs diff=lfs merge=lfs -text
23
+ *.rar filter=lfs diff=lfs merge=lfs -text
24
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
25
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
27
+ *.tflite filter=lfs diff=lfs merge=lfs -text
28
+ *.tgz filter=lfs diff=lfs merge=lfs -text
29
+ *.wasm filter=lfs diff=lfs merge=lfs -text
30
+ *.xz filter=lfs diff=lfs merge=lfs -text
31
+ *.zip filter=lfs diff=lfs merge=lfs -text
32
+ *.zst filter=lfs diff=lfs merge=lfs -text
33
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Img to Music Video
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.10.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: unknown
11
+ duplicated_from: doevent/msk
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import gradio as gr
4
+ from sentence_transformers import SentenceTransformer
5
+
6
+ import httpx
7
+ import json
8
+
9
+ from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat
10
+ #import subprocess
11
+ import os
12
+ import uuid
13
+ from tempfile import gettempdir
14
+ from PIL import Image
15
+ import cv2
16
+ from pprint import pprint
17
+
18
# Sentence-embedding model; it encodes both the Mubert tags and each prompt.
minilm = SentenceTransformer("all-MiniLM-L6-v2")
# Tag embeddings are computed once at import time and reused per request.
mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)

# Previous captioning backend, kept for reference:
# image_to_text = gr.Interface.load("spaces/doevent/image_to_text", api_key=os.environ['HF_TOKEN'])
# Remote Space used to turn the uploaded image into a text prompt.
image_to_text = gr.Blocks.load(name="spaces/pharma/CLIP-Interrogator")
23
def center_crop(img, dim: tuple = (512, 512)):
    """Return the centered crop of an image array.

    Args:
        img: image array indexed as ``img[row, col, ...]`` (height first).
        dim: requested ``(width, height)`` of the crop. Each side is clamped
            to the corresponding side of ``img``.

    Returns:
        A slice of ``img`` that is exactly ``min(dim[0], width)`` wide and
        ``min(dim[1], height)`` tall, centered in the source image.
    """
    height, width = img.shape[0], img.shape[1]

    # Never ask for more pixels than the image has.
    crop_width = min(dim[0], width)
    crop_height = min(dim[1], height)

    # Derive the start offset from the leftover margin instead of from two
    # truncated halves, so odd crop sizes do not come out one pixel short.
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    return img[top:top + crop_height, left:left + crop_width]
40
+
41
+
42
def scale_image(img, factor=1):
    """Resize an image by one scale factor applied to both sides.

    Args:
        img: image array indexed as ``img[row, col, ...]``.
        factor: multiplier applied to both width and height; the results
            are truncated to integers.

    Returns:
        The image returned by ``cv2.resize`` for the scaled size.
    """
    src_height, src_width = img.shape[0], img.shape[1]
    # cv2.resize takes the target size as (width, height).
    target_size = (int(src_width * factor), int(src_height * factor))
    return cv2.resize(img, target_size)
50
+
51
+
52
def get_track_by_tags(tags, pat, duration, maxit=20, loop=False):
    """Request a generated track from the Mubert API and wait until it is ready.

    Args:
        tags: value sent as the ``tags`` request parameter.
        pat: access token sent as the ``pat`` request parameter.
        duration: value sent as the ``duration`` request parameter.
        maxit: maximum number of polls of the download link, one second apart.
        loop: request mode ``"loop"`` when true, ``"track"`` otherwise.

    Returns:
        The download URL, once a GET on it answers with HTTP 200.

    Raises:
        RuntimeError: if the API response status is not 1, or the download
            link is still not serving the track after ``maxit`` polls.
    """
    mode = "loop" if loop else "track"
    r = httpx.post(
        'https://api-b2b.mubert.com/v2/RecordTrackTTM',
        json={
            "method": "RecordTrackTTM",
            "params": {
                "pat": pat,
                "duration": duration,
                "tags": tags,
                "mode": mode,
            },
        },
    )

    pprint(r.text)
    rdata = json.loads(r.text)
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if rdata.get('status') != 1:
        error_text = (rdata.get('error') or {}).get('text', r.text)
        raise RuntimeError(error_text)
    trackurl = rdata['data']['tasks'][0]['download_link']

    # Poll the link until it answers with 200.
    for _ in range(maxit):
        r = httpx.get(trackurl)
        if r.status_code == 200:
            return trackurl
        time.sleep(1)

    # Fail loudly rather than returning None, which the caller cannot use.
    raise RuntimeError(f"Track was not ready after {maxit} attempts: {trackurl}")
79
+
80
+
81
def generate_track_by_prompt(image, email, duration, loop=False):
    """Turn an image into a music video with a waveform overlay.

    The image is normalised to a temporary PNG (and downscaled when large),
    captioned through ``image_to_text``, mapped to Mubert tags, and a track
    is requested for those tags. The track is then rendered by ffmpeg into
    an MP4 with the waveform drawn over the image.

    Args:
        image: path to the input image file.
        email: e-mail passed to ``get_pat`` to obtain a Mubert access token.
        duration: track duration; converted with ``int()`` before use.
        loop: forwarded to ``get_track_by_tags``.

    Returns:
        Tuple ``(mp4_path, track_url, prompt, tags)``.

    Raises:
        gr.Error: on any failure, including a failing wget/ffmpeg call.
    """
    # Local import: the module-level ``import subprocess`` is commented out.
    import subprocess

    tmp_dir = gettempdir()
    filepath_png = f"{tmp_dir}/{uuid.uuid4().hex}.png"
    filepath_mp3 = None
    try:
        with Image.open(image) as im:
            src_width, src_height = im.size
            im.convert("RGB").save(filepath_png)

        longest_side = max(src_width, src_height)
        if longest_side > 3501:
            # The message now states the limit that is actually enforced.
            raise gr.Error("Image width and height must not exceed 3501 px.")
        # NOTE(review): with the 3501 px cap above, the 3500 tier only ever
        # matches a side of exactly 3501 px. Thresholds are kept as they were;
        # confirm the intended limits.
        for threshold, factor in ((3500, 0.2), (1800, 0.3), (900, 0.5)):
            if longest_side > threshold:
                scaled = scale_image(cv2.imread(image), factor=factor)
                cv2.imwrite(filepath_png, scaled)
                break

        # prompt = image_to_text(filepath_png, "Image Captioning", "", "Nucleus sampling")
        prompt = image_to_text(filepath_png, "ViT-L (best for Stable Diffusion 1.*)", "Fast", fn_index=1)[0]
        print(f"PROMPT: {prompt}")

        pat = get_pat(email)
        _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt, ])[0]
        track_url = get_track_by_tags(tags, pat, int(duration), loop=loop)

        filename_mp3 = track_url.split("/")[-1]
        filepath_mp3 = f"{tmp_dir}/{filename_mp3}"
        filepath_mp4 = f"{tmp_dir}/{uuid.uuid4().hex}.mp4"

        # Argument lists (no shell) keep URL and path characters from being
        # interpreted by a shell; check=True surfaces download failures.
        # -O pins the output path so a leftover file cannot redirect wget
        # to a "name.1" copy.
        subprocess.run(["wget", track_url, "-O", filepath_mp3], check=True)

        # The waveform is rendered at the size of the prepared PNG.
        with Image.open(filepath_png) as im:
            width, height = im.size
        print(f"{width}x{height}")
        waveform_filter = (
            f"[0:a]showwaves=s={width}x{height}:colors=0xffffff:mode=cline,"
            "format=rgba[v];[1:v][v]overlay[outv]"
        )
        subprocess.run(
            [
                "ffmpeg", "-hide_banner", "-loglevel", "warning", "-y",
                "-i", filepath_mp3,
                "-loop", "1", "-i", filepath_png,
                "-filter_complex", waveform_filter,
                "-map", "[outv]", "-map", "0:a",
                "-c:v", "libx264", "-r", "15",
                "-c:a", "copy", "-pix_fmt", "yuv420p",
                "-shortest", filepath_mp4,
            ],
            check=True,
        )

        return filepath_mp4, track_url, prompt, tags
    except gr.Error:
        # Already user-facing; do not re-wrap.
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e
    finally:
        # Remove intermediate files on success and on every failure path.
        for leftover in (filepath_png, filepath_mp3):
            if leftover and os.path.exists(leftover):
                os.remove(leftover)
135
+
136
+
137
# UI wiring: the three inputs are listed in the order of the image, email and
# duration parameters of generate_track_by_prompt; the four outputs are listed
# in the order of the tuple it returns.
iface = gr.Interface(
    fn=generate_track_by_prompt,
    inputs=[
        gr.Image(type="filepath"),
        "text",
        gr.Slider(label="duration (seconds)", value=30, minimum=10, maximum=60),
    ],
    outputs=[
        gr.Video(label="Video"),
        gr.Audio(label="Audio"),
        gr.Text(label="Prompt"),
        gr.Text(label="Tags"),
    ],
)
iface.queue().launch()
constants.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
import os

import numpy as np

# Comma-separated list of Mubert tags. It must stay on a single line: a
# line break inside the quoted literal is a syntax error.
MUBERT_TAGS_STRING = 'tribal,action,kids,neo-classic,run 130,pumped,jazz / funk,ethnic,dubtechno,reggae,acid jazz,liquidfunk,funk,witch house,tech house,underground,artists,mystical,disco,sensorium,r&b,agender,psychedelic trance / psytrance,peaceful,run 140,piano,run 160,setting,meditation,christmas,ambient,horror,cinematic,electro house,idm,bass,minimal,underscore,drums,glitchy,beautiful,technology,tribal house,country pop,jazz & funk,documentary,space,classical,valentines,chillstep,experimental,trap,new jack swing,drama,post-rock,tense,corporate,neutral,happy,analog,funky,spiritual,sberzvuk special,chill hop,dramatic,catchy,holidays,fitness 90,optimistic,orchestra,acid techno,energizing,romantic,minimal house,breaks,hyper pop,warm up,dreamy,dark,urban,microfunk,dub,nu disco,vogue,keys,hardcore,aggressive,indie,electro funk,beauty,relaxing,trance,pop,hiphop,soft,acoustic,chillrave / ethno-house,deep techno,angry,dance,fun,dubstep,tropical,latin pop,heroic,world music,inspirational,uplifting,atmosphere,art,epic,advertising,chillout,scary,spooky,slow ballad,saxophone,summer,erotic,jazzy,energy 100,kara mar,xmas,atmospheric,indie pop,hip-hop,yoga,reggaeton,lounge,travel,running,folk,chillrave & ethno-house,detective,darkambient,chill,fantasy,minimal techno,special,night,tropical house,downtempo,lullaby,meditative,upbeat,glitch hop,fitness,neurofunk,sexual,indie rock,future pop,jazz,cyberpunk,melancholic,happy hardcore,family / kids,synths,electric guitar,comedy,psychedelic trance & psytrance,edm,psychedelic rock,calm,zen,bells,podcast,melodic house,ethnic percussion,nature,heavy,bassline,indie dance,techno,drumnbass,synth pop,vaporwave,sad,8-bit,chillgressive,deep,orchestral,futuristic,hardtechno,nostalgic,big room,sci-fi,tutorial,joyful,pads,minimal 170,drill,ethnic 108,amusing,sleepy ambient,psychill,italo disco,lofi,house,acoustic guitar,bassline house,rock,k-pop,synthwave,deep house,electronica,gabber,nightlife,sport & fitness,road trip,celebration,electro,disco house,electronic'
# Array form so the tags can be indexed with numpy index arrays.
MUBERT_TAGS = np.array(MUBERT_TAGS_STRING.split(','))

# NOTE(security): these credentials are committed in source. Each can be
# overridden through an environment variable of the same name; the
# committed value remains the default so existing deployments still work.
MUBERT_LICENSE = os.environ.get(
    "MUBERT_LICENSE",
    "ttmmubertlicense#f0acYBenRcfeFpNT4wpYGaTQIyDI4mJGv5MfIhBFz97NXDwDNFHmMRsBSzmGsJwbTpP1A6i07AXcIeAHo5",
)
MUBERT_MODE = "loop"
MUBERT_TOKEN = os.environ.get("MUBERT_TOKEN", "4951f6428e83172a4f39de05d5b3ab10d58560b8")
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ httpx
2
+ sentence-transformers
3
+ opencv-python
utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import httpx
4
+
5
+ from constants import MUBERT_TAGS, MUBERT_LICENSE, MUBERT_MODE, MUBERT_TOKEN
6
+
7
+
8
def get_mubert_tags_embeddings(w2v_model):
    """Encode every entry of ``MUBERT_TAGS`` with the given model.

    Args:
        w2v_model: object exposing ``encode``; it is called once with the
            full tag array.

    Returns:
        Whatever ``w2v_model.encode`` returns for the tag array.
    """
    tag_embeddings = w2v_model.encode(MUBERT_TAGS)
    return tag_embeddings
10
+
11
+
12
def get_pat(email: str):
    """Obtain a Mubert access token (``pat``) for the given e-mail.

    Posts a ``GetServiceAccess`` request carrying the e-mail together with
    the module's license, token and mode constants.

    Args:
        email: e-mail address sent as the ``email`` request parameter.

    Returns:
        The ``pat`` value from the response's ``data`` object.

    Raises:
        RuntimeError: if the response status is not 1.
    """
    r = httpx.post(
        'https://api-b2b.mubert.com/v2/GetServiceAccess',
        json={
            "method": "GetServiceAccess",
            "params": {
                "email": email,
                "license": MUBERT_LICENSE,
                "token": MUBERT_TOKEN,
                "mode": MUBERT_MODE,
            },
        },
    )

    rdata = json.loads(r.text)
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would turn a failure into a KeyError below.
    if rdata.get('status') != 1:
        raise RuntimeError("probably incorrect e-mail")
    return rdata['data']['pat']
28
+
29
+
30
def find_similar(em, embeddings, method='cosine'):
    """Score every reference embedding against a query embedding.

    Args:
        em: query embedding vector.
        embeddings: iterable of reference vectors comparable with ``em``.
        method: ``'cosine'`` for cosine distance (``1 - cosine similarity``)
            or ``'norm'`` for Euclidean distance.

    Returns:
        Tuple ``(scores, order)``: ``scores`` holds one distance per
        reference (lower means more similar) and ``order`` holds the indices
        that sort ``scores`` ascending.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    # Pick the metric once, outside the loop; an unknown name used to yield
    # empty results silently.
    if method == 'cosine':
        em_norm = np.linalg.norm(em)  # invariant across references
        scores = [
            1 - np.dot(ref, em) / (np.linalg.norm(ref) * em_norm)
            for ref in embeddings
        ]
    elif method == 'norm':
        scores = [np.linalg.norm(ref - em) for ref in embeddings]
    else:
        raise ValueError(f"unknown method {method!r}; expected 'cosine' or 'norm'")
    return np.array(scores), np.argsort(scores)
38
+
39
+
40
def get_tags_for_prompts(w2v_model, mubert_tags_embeddings, prompts, top_n=3, debug=False):
    """Pick the closest Mubert tags for each prompt.

    Args:
        w2v_model: object exposing ``encode``; called once with ``prompts``.
        mubert_tags_embeddings: embeddings of ``MUBERT_TAGS`` to rank against.
        prompts: indexable collection of prompt strings.
        top_n: number of best-ranked tags to keep per prompt.
        debug: when true, print each prompt with its tags and scores.

    Returns:
        List of ``(prompt, tags)`` tuples, one per prompt, where ``tags`` is
        a list of the ``top_n`` best-ranked tags.
    """
    encoded_prompts = w2v_model.encode(prompts)
    results = []
    for idx, embedding in enumerate(encoded_prompts):
        prompt = prompts[idx]
        distances, ranking = find_similar(embedding, mubert_tags_embeddings)
        best = ranking[:top_n]
        chosen_tags = MUBERT_TAGS[best]
        if debug:
            # Scores are distances; report them as similarities.
            similarity = 1 - distances[best]
            print(f"Prompt: {prompt}\nTags: {', '.join(chosen_tags)}\nScores: {similarity}\n\n\n")
        results.append((prompt, list(chosen_tags)))
    return results