Source code for botiverse.preprocessors.Wav2Vec.Wav2Vec
try:
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2PhonemeCTCTokenizer
import os
import torch
import torchaudio
from tqdm import tqdm
import numpy as np
# disable warnings from this file
from transformers import logging
except:
pass
[docs]class Wav2Vec():
'''
An interface for transforming audio files into wav2vec vectors.
'''
def __init__(self, sample_rate=16000, duration=1, augment=None):
'''
Initialize the Wav2Vec transformer by loading the wav2vec model and setting the sample rate and duration of the audio files.
:param sample_rate: The sample rate of the audio files
:type sample_rate: int
:param duration: The duration of the audio files in milliseconds
:type duration: int
:param augment: The audio augmentations to apply to the audio files.
:type augment: audiomentations.Compose
'''
logging.set_verbosity_error()
self.model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
self.extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
self.sample_rate = sample_rate
self.duration = duration
self.emb_dim = 768 * int(49 * self.duration) # fact regarding wav2vec2-base-960h
if augment is None:
self.augment = Compose([
TimeStretch(min_rate=0.8, max_rate=1.4, p=0.7),
PitchShift(min_semitones=-4, max_semitones=1, p=0.8),
Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5),
AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
])
else:
self.augment = augment
[docs] def transform_list(self, words, n=4):
'''
Given a folder dataset with folders each containing audio files, this returns a table of wav2vec vectors (one for each audio file) in the form of a numpy array X and a table of classes in the form of a numpy array y.
Note that in the process, each audio file is augmented n times and each corresponds to another wav2vec vector.
:param words: A list of words which are the classes of the speech classifier.
:type words: list
:param n: The number of times to augment each audio file.
:type n: int
:return: A tuple of the form (X, y) where X is a 3D numpy array representing the wav2vec vectors and y is a 1D numpy array representing the classes of the audio files.
'''
sounds_per_word = len(os.listdir(f"dataset/{words[0]}"))
self.N = len(words) * sounds_per_word
X, y = [], []
print("Transforming audio files into embeddings...")
for word in tqdm(words):
for i, file in enumerate(os.listdir(f"dataset/{word}")):
if not file.endswith(".wav"): continue
waveform, sr = torchaudio.load(f"dataset/{word}/{file}")
# resample if sr != 24K
waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
if waveform.shape[0] == 2:
waveform = torch.mean(waveform, dim=0, keepdim=True)
length = waveform.shape[1]
sample_dur = int(self.duration * self.sample_rate)
if length < sample_dur:
waveform = torch.cat((waveform, torch.zeros(1, sample_dur - length)), dim=1)
elif length > sample_dur:
waveform = waveform[:, :sample_dur]
waveform = waveform.squeeze()
waveform = waveform.detach().numpy()
if n >0:
for _ in range(n):
waveform = self.augment(samples=waveform, sample_rate=self.sample_rate)
inputs = self.extractor(waveform, return_tensors="pt", padding=True, sampling_rate=self.sample_rate)
features = self.model(inputs.input_values).last_hidden_state
features = features.squeeze().detach().numpy()
X.append(features)
y.append(words.index(word))
else:
inputs = self.extractor(waveform, return_tensors="pt", padding=True, sampling_rate=self.sample_rate)
features = self.model(inputs.input_values).last_hidden_state
features = features.squeeze().detach().numpy()
X.append(features)
y.append(words.index(word))
X = np.array(X)
y = np.array(y)
return X, y
[docs] def transform(self, path, strict_duration=False):
'''
Convert the audio file as in the path into a wav2vec vector.
:param path: The path to the audio file
:type path: str
:param strict_duration: If True, the audio file is padded or truncated to the duration specified during init.
:type strict_duration: bool
:return: The wav2vec vector of the audio file as a 2D numpy array.
:rtype: numpy.ndarray
'''
waveform, sr = torchaudio.load(path)
waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
if waveform.shape[0] == 2: waveform = torch.mean(waveform, dim=0, keepdim=True)
length = waveform.shape[1]
if strict_duration:
sample_dur = int(self.duration * self.sample_rate)
if length < sample_dur:
waveform = torch.cat((waveform, torch.zeros(1, sample_dur - length)), dim=1)
elif length > sample_dur:
waveform = waveform[:, :sample_dur]
waveform = waveform.squeeze()
inputs = self.extractor(waveform, return_tensors="pt", padding=True, sampling_rate=self.sample_rate)
features = self.model(inputs.input_values).last_hidden_state.detach().numpy()
return features
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
from transformers import logging
# Load the pre-trained model and tokenizer
[docs]class Wav2Text():
''' An interface for converting speech files into text using wav2vec2.'''
def __init__(self):
'''Load the pre-trained model and tokenizer'''
logging.set_verbosity_error()
self.model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
self.tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
[docs] def transcribe(self, path):
'''
Given a path to a speech file, return the transcription of the speech file.
:param path: The path to the speech wav file
:type path: str
:return: The transcription of the speech file
:rtype: str
'''
# load audio
waveform, sample_rate = torchaudio.load(path)
# resample if sr not 16K
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
waveform = resampler(waveform)
waveform = waveform.squeeze()
# preprocess the audio
input_values = self.tokenizer(waveform, return_tensors="pt").input_values
# Perform speech-to-text conversion
with torch.no_grad():
logits = self.model(input_values).logits
# Decode the predicted transcription
predicted_ids = torch.argmax(logits, dim=2)
transcription = self.tokenizer.batch_decode(predicted_ids)[0]
return transcription.lower()