prithivMLmods committed
Commit e5d74a2 · verified · 1 Parent(s): e42e588

Delete app.py

Files changed (1)
  1. app.py +0 -314
app.py DELETED
@@ -1,314 +0,0 @@
import gradio as gr
import torch
import torchaudio
import numpy as np
import os
import tempfile
import spaces

from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)

class OrangeRedTheme(Soft):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,  # Use the new color
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )

orange_red_theme = OrangeRedTheme()

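# SAM-Audio ships as a separate package; if the import fails the app still launches,
# but process_audio will report that the model could not be loaded.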
try:
    from sam_audio import SAMAudio, SAMAudioProcessor
except ImportError as e:
    print(f"Warning: 'sam_audio' library not found. Please install it to use this app. Error: {e}")

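# Inference settings: audio longer than MAX_DURATION_WITHOUT_CHUNKING seconds is split into
# DEFAULT_CHUNK_DURATION-second chunks that overlap by OVERLAP_DURATION seconds; the overlap
# is crossfaded away when the separated chunks are stitched back together.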
MODEL_ID = "facebook/sam-audio-large"
DEFAULT_CHUNK_DURATION = 30.0
OVERLAP_DURATION = 2.0
MAX_DURATION_WITHOUT_CHUNKING = 30.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Loading {MODEL_ID} on {device}...")

model = None
processor = None

try:
    model = SAMAudio.from_pretrained(MODEL_ID).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(MODEL_ID)
    print("✅ SAM-Audio loaded successfully.")
except Exception as e:
    print(f"❌ Error loading SAM-Audio: {e}")

def load_audio(file_path):
    """Load audio from file (supports both audio and video files)."""
    waveform, sample_rate = torchaudio.load(file_path)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform, sample_rate

def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
    """Split audio waveform into overlapping chunks."""
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    stride = chunk_samples - overlap_samples

    chunks = []
    total_samples = waveform.shape[1]

    if total_samples <= chunk_samples:
        return [waveform]

    start = 0
    while start < total_samples:
        end = min(start + chunk_samples, total_samples)
        chunk = waveform[:, start:end]
        chunks.append(chunk)
        if end >= total_samples:
            break
        start += stride

    return chunks

def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
    """Merge audio chunks with crossfade on overlapping regions."""
    if len(chunks) == 1:
        chunk = chunks[0]
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        return chunk

    overlap_samples = int(overlap_duration * sample_rate)

    processed_chunks = []
    for chunk in chunks:
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        processed_chunks.append(chunk)

    result = processed_chunks[0]

    for i in range(1, len(processed_chunks)):
        prev_chunk = result
        next_chunk = processed_chunks[i]

        actual_overlap = min(overlap_samples, prev_chunk.shape[1], next_chunk.shape[1])

        if actual_overlap <= 0:
            result = torch.cat([prev_chunk, next_chunk], dim=1)
            continue

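        # Linear crossfade over the shared overlap: the outgoing chunk ramps from 1 to 0
        # while the incoming chunk ramps from 0 to 1, so the two gains always sum to 1.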
        fade_out = torch.linspace(1.0, 0.0, actual_overlap).to(prev_chunk.device)
        fade_in = torch.linspace(0.0, 1.0, actual_overlap).to(next_chunk.device)

        prev_overlap = prev_chunk[:, -actual_overlap:]
        next_overlap = next_chunk[:, :actual_overlap]

        crossfaded = prev_overlap * fade_out + next_overlap * fade_in

        result = torch.cat([
            prev_chunk[:, :-actual_overlap],
            crossfaded,
            next_chunk[:, actual_overlap:]
        ], dim=1)

    return result

def save_audio(tensor, sample_rate):
    """Saves a tensor to a temporary WAV file and returns path."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tensor = tensor.cpu()
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
        torchaudio.save(tmp.name, tensor, sample_rate)
        return tmp.name

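# On Hugging Face Spaces, @spaces.GPU runs this function on a ZeroGPU-allocated GPU;
# duration=120 requests up to 120 seconds of GPU time per call.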
@spaces.GPU(duration=120)
def process_audio(file_path, text_prompt, chunk_duration_val, progress=gr.Progress()):
    global model, processor

    if model is None or processor is None:
        return None, None, "❌ Model not loaded correctly. Check logs."

    progress(0.05, desc="Checking inputs...")

    if not file_path:
        return None, None, "❌ Please upload an audio or video file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "❌ Please enter a text prompt."

    try:
        progress(0.15, desc="Loading audio...")
        waveform, sample_rate = load_audio(file_path)
        duration = waveform.shape[1] / sample_rate

        c_dur = chunk_duration_val if chunk_duration_val else DEFAULT_CHUNK_DURATION
        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING

        if use_chunking:
            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
            chunks = split_audio_into_chunks(waveform, sample_rate, c_dur, OVERLAP_DURATION)
            num_chunks = len(chunks)

            target_chunks = []
            residual_chunks = []

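            # The processor is given file paths here, so each chunk is written to a temporary
            # WAV and always cleaned up again in the finally block below.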
            for i, chunk in enumerate(chunks):
                chunk_progress = 0.2 + (i / num_chunks) * 0.6
                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")

                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    torchaudio.save(tmp.name, chunk, sample_rate)
                    chunk_path = tmp.name

                try:
                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)

                    with torch.inference_mode():
                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)

                    target_chunks.append(result.target[0].detach().cpu())
                    residual_chunks.append(result.residual[0].detach().cpu())
                finally:
                    if os.path.exists(chunk_path):
                        os.unlink(chunk_path)

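            # Note: the merged outputs are saved at the input file's sample rate; the single-pass
            # branch below uses processor.audio_sampling_rate instead, so if the model resamples
            # internally that rate would be the more appropriate choice here too.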
            progress(0.85, desc="Merging chunks...")
            target_merged = merge_chunks_with_crossfade(target_chunks, sample_rate, OVERLAP_DURATION)
            residual_merged = merge_chunks_with_crossfade(residual_chunks, sample_rate, OVERLAP_DURATION)

            progress(0.95, desc="Saving results...")
            target_path = save_audio(target_merged, sample_rate)
            residual_path = save_audio(residual_merged, sample_rate)

            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' ({num_chunks} chunks)"

        else:
            progress(0.3, desc="Processing audio...")
            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)

            progress(0.6, desc="Separating sounds...")
            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)

            progress(0.9, desc="Saving results...")
            sr = processor.audio_sampling_rate
            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sr)
            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sr)

            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}'"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Error: {str(e)}"

css = """
#main-title h1 {font-size: 2.4em}
"""

with gr.Blocks(theme=orange_red_theme, css=css) as demo:
    gr.Markdown("# **SAM-Audio-Demo**", elem_id="main-title")
    gr.Markdown("Segment and isolate specific music/sounds from audio files using natural language descriptions, powered by [SAM-Audio-Large](https://huggingface.co/facebook/sam-audio-large).")

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column(scale=1):
                input_file = gr.Audio(label="Input Audio", type="filepath")
                text_prompt = gr.Textbox(label="Sound to Isolate", placeholder="e.g., 'A man speaking', 'Bird chirping'")

                with gr.Accordion("Advanced Settings", open=False):
                    chunk_duration_slider = gr.Slider(
                        minimum=10, maximum=60, value=30, step=5,
                        label="Chunk Duration (seconds)",
                        info="Processing long audio in chunks prevents out-of-memory errors."
                    )

                run_btn = gr.Button("Segment Audio", variant="primary")

            with gr.Column(scale=1):
                output_target = gr.Audio(label="Isolated Sound (Target)", type="filepath")
                output_residual = gr.Audio(label="Background (Residual)", type="filepath")
                status_out = gr.Textbox(label="Status", interactive=False, show_label=True, lines=6)

        gr.Examples(
            examples=[
                ["example_audio/speech.mp3", "Music", 30],
                ["example_audio/song.mp3", "Drum", 30],
                ["example_audio/song2.mp3", "Music", 30],
            ],
            inputs=[input_file, text_prompt, chunk_duration_slider],
            label="Audio Examples"
        )

    run_btn.click(
        fn=process_audio,
        inputs=[input_file, text_prompt, chunk_duration_slider],
        outputs=[output_target, output_residual, status_out]
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True, ssr_mode=False)