Source code for wtf_transcript_converter.providers.parakeet

"""
Parakeet provider converter implementation.

This module provides conversion between Parakeet (NVIDIA NeMo) transcription format and WTF.
"""

import os
import re
from typing import Any, Dict, List

from wtf_transcript_converter.core.models import (
    WTFAudio,
    WTFDocument,
    WTFMetadata,
    WTFQuality,
    WTFSegment,
    WTFSpeaker,
    WTFTranscript,
    WTFWord,
)
from wtf_transcript_converter.providers.base import BaseProviderConverter
from wtf_transcript_converter.utils.confidence_utils import normalize_confidence
from wtf_transcript_converter.utils.language_utils import normalize_language_code
from wtf_transcript_converter.utils.time_utils import get_current_iso_timestamp

# Optional heavy dependencies: they are only needed to run the model locally
# (see ParakeetConverter._load_model / transcribe_audio), not to convert
# Parakeet JSON that has already been produced.
try:
    import librosa
    import torch
    from transformers import pipeline

    HF_AVAILABLE = True
except ImportError:
    # Any one of the three imports failing disables local transcription;
    # ParakeetConverter._load_model checks this flag and raises ImportError.
    HF_AVAILABLE = False


class ParakeetConverter(BaseProviderConverter):
    """
    Converter for Parakeet (NVIDIA NeMo) transcription format to/from WTF.

    The converter works on plain dictionaries with the keys ``text``,
    ``language``, ``duration``, ``words``, ``segments``, ``model``,
    ``audio_path`` and ``sample_rate`` (the structure produced by
    :meth:`_format_parakeet_result`). It can optionally run the model itself
    through a Hugging Face ASR pipeline when the optional dependencies are
    installed.
    """

    provider_name: str = "parakeet"
    description: str = "NVIDIA Parakeet speech recognition via Hugging Face"
    status: str = "Implemented"

    def __init__(
        self,
        provider_name: str = "parakeet",
        model_name: str = "nvidia/parakeet-tdt-0.6b-v3",
    ) -> None:
        """
        Create a converter; the model itself is loaded lazily.

        Args:
            provider_name: Provider identifier forwarded to the base converter
            model_name: Hugging Face model id used when transcribing audio
        """
        super().__init__(provider_name)
        self.model_name = model_name
        self._pipeline: Any = None
        self._tokenizer: Any = None
        self._model: Any = None

    def _load_model(self) -> None:
        """
        Load the Parakeet ASR pipeline on first use.

        Raises:
            ImportError: If the optional Hugging Face dependencies are missing
            RuntimeError: If the pipeline cannot be constructed
        """
        if not HF_AVAILABLE:
            raise ImportError(
                "Hugging Face transformers, torch, librosa, and soundfile are required for Parakeet support"
            )

        if self._pipeline is not None:
            # Already loaded by an earlier call.
            return

        # Set Hugging Face token if available
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token

        try:
            # Load the ASR pipeline
            self._pipeline = pipeline(
                "automatic-speech-recognition",
                model=self.model_name,
                token=hf_token,
                device=0 if torch.cuda.is_available() else -1,
            )
        except Exception as e:
            # Chain the cause so the underlying loader traceback is preserved.
            raise RuntimeError(f"Failed to load Parakeet model {self.model_name}: {e}") from e

    def transcribe_audio(self, audio_path: str, language: str = "en") -> Dict[str, Any]:
        """
        Transcribe audio file using Parakeet model.

        Args:
            audio_path: Path to audio file
            language: Language code (e.g., 'en', 'es', 'fr')

        Returns:
            Parakeet transcription result

        Raises:
            ImportError: If the optional Hugging Face dependencies are missing
            RuntimeError: If the model cannot be loaded or transcription fails
        """
        self._load_model()

        try:
            # Load and preprocess audio
            audio, sample_rate = librosa.load(audio_path, sr=16000)

            # Transcribe using the pipeline
            if self._pipeline is None:
                raise RuntimeError("Pipeline not loaded")

            result = self._pipeline(
                audio,
                return_timestamps=True,
                chunk_length_s=30,
                stride_length_s=5,
                language=language,
            )

            # Convert to our expected format
            formatted = self._format_parakeet_result(result, audio_path, sample_rate)

            # The formatter can only estimate duration from chunk timestamps;
            # the decoded waveform gives the exact value.
            if sample_rate:
                formatted["duration"] = len(audio) / sample_rate
            # The formatter defaults to "en"; report the language that was
            # actually requested for this transcription.
            formatted["language"] = language
            return formatted

        except Exception as e:
            raise RuntimeError(f"Parakeet transcription failed: {e}") from e

    def _format_parakeet_result(
        self, result: Dict[str, Any], audio_path: str, sample_rate: int
    ) -> Dict[str, Any]:
        """
        Format Parakeet pipeline result to our expected structure.

        Each non-empty chunk becomes one segment. The pipeline result carries
        no per-word timing here, so word timestamps are interpolated evenly
        across the chunk they belong to.

        Args:
            result: Pipeline output with ``text`` and ``chunks`` entries
            audio_path: Path of the transcribed audio file
            sample_rate: Sample rate (Hz) the audio was decoded at

        Returns:
            Dictionary with text, language, duration, words, segments, model,
            audio_path and sample_rate
        """
        text = result.get("text", "")
        chunks = result.get("chunks") or []

        # Extract words and segments
        words: List[Dict[str, Any]] = []
        segments: List[Dict[str, Any]] = []
        word_id = 0
        segment_id = 0
        last_end = 0.0

        for chunk in chunks:
            chunk_text = (chunk.get("text") or "").strip()
            if not chunk_text:
                continue

            # A timestamp entry may be missing or contain None; fall back to
            # the previous chunk's end so the arithmetic below never sees None.
            timestamp = chunk.get("timestamp") or (0.0, 0.0)
            chunk_start = timestamp[0] if timestamp[0] is not None else last_end
            chunk_end = timestamp[1] if timestamp[1] is not None else chunk_start
            last_end = chunk_end

            # Create segment
            segment: Dict[str, Any] = {
                "id": segment_id,
                "start": chunk_start,
                "end": chunk_end,
                "text": chunk_text,
                "confidence": 0.9,  # Default confidence
            }
            segments.append(segment)

            # Split into words; str.split() never yields empty tokens and
            # chunk_text is non-empty, so chunk_words has at least one entry.
            chunk_words = chunk_text.split()
            chunk_span = chunk_end - chunk_start
            word_count = len(chunk_words)
            segment_word_ids = []

            # Interpolate with the word's position *inside this chunk*, not
            # the global word id, so timings stay within [start, end].
            for position, word_text in enumerate(chunk_words):
                words.append(
                    {
                        "id": word_id,
                        "start": chunk_start + (position * chunk_span / word_count),
                        "end": chunk_start + ((position + 1) * chunk_span / word_count),
                        "text": word_text,
                        "confidence": 0.9,
                    }
                )
                segment_word_ids.append(word_id)
                word_id += 1

            segment["words"] = segment_word_ids
            segment_id += 1

        # Calculate duration: prefer raw samples when the result carries them,
        # otherwise use the latest segment end instead of reporting zero.
        raw = result.get("raw")
        if raw is not None and sample_rate:
            duration = len(raw) / sample_rate
        else:
            duration = max((seg["end"] for seg in segments), default=0.0)

        return {
            "text": text,
            "language": "en",  # Default, could be extracted from model
            "duration": duration,
            "words": words,
            "segments": segments,
            "model": self.model_name,
            "audio_path": audio_path,
            "sample_rate": sample_rate,
        }

    def convert_to_wtf(self, parakeet_data: Dict[str, Any]) -> WTFDocument:
        """
        Convert Parakeet JSON data to WTF format.

        Args:
            parakeet_data: Parakeet JSON data structure

        Returns:
            WTF document
        """
        # Extract basic transcript information
        text = (parakeet_data.get("text") or "").strip()
        if not text:
            text = "[Empty transcript]"  # Use a meaningful placeholder for empty transcripts

        transcript = WTFTranscript(
            text=text,
            language=normalize_language_code(parakeet_data.get("language", "en")),
            duration=parakeet_data.get("duration", 0.0),
            confidence=self._calculate_overall_confidence(parakeet_data),
        )

        # Convert words
        words_data = parakeet_data.get("words", [])
        wtf_words = self._convert_parakeet_words(words_data)

        # Convert segments
        segments_data = parakeet_data.get("segments", [])
        wtf_segments = self._convert_parakeet_segments(segments_data, wtf_words)

        # Extract speaker information (Parakeet doesn't do diarization by default)
        speakers = self._extract_speakers(words_data)

        # Create metadata
        current_time = get_current_iso_timestamp()
        audio_duration = parakeet_data.get("duration", 0.0)

        audio_metadata = WTFAudio(
            duration=audio_duration,
            sample_rate=parakeet_data.get("sample_rate", 16000),
            channels=None,
            format=None,
            bitrate=None,
        )

        metadata = WTFMetadata(
            created_at=current_time,
            processed_at=current_time,
            provider=self.provider_name,
            model=parakeet_data.get("model", self.model_name),
            processing_time=None,
            audio=audio_metadata,
            options={
                "audio_path": parakeet_data.get("audio_path"),
                "model_name": self.model_name,
                "chunk_length_s": 30,
                "stride_length_s": 5,
            },
        )

        # Clean options to remove None values
        metadata.options = {k: v for k, v in metadata.options.items() if v is not None}

        # Calculate quality metrics
        quality = self._calculate_quality_metrics(parakeet_data, wtf_words)

        # Preserve other Parakeet-specific fields in extensions
        extensions = {"parakeet_raw_response": parakeet_data}

        return WTFDocument(
            transcript=transcript,
            segments=wtf_segments,
            metadata=metadata,
            words=wtf_words if wtf_words else None,
            speakers=speakers if speakers else None,
            alternatives=None,
            enrichments=None,
            extensions=extensions if extensions else None,
            quality=quality,
            streaming=None,
        )

    def convert_from_wtf(self, wtf_doc: WTFDocument) -> Dict[str, Any]:
        """
        Convert WTF document to Parakeet JSON format.

        When the document carries the original response under
        ``extensions["parakeet_raw_response"]``, that response is merged over
        the converted fields (so raw ``words``/``segments``/``model`` win),
        after which ``text``, ``duration`` and ``language`` are taken from the
        WTF transcript again.

        Args:
            wtf_doc: WTF document

        Returns:
            Parakeet JSON data structure
        """
        # Parakeet uses bare language codes ("en"), not regional tags ("en-US").
        base_language = wtf_doc.transcript.language.split("-")[0]

        parakeet_data: Dict[str, Any] = {
            "text": wtf_doc.transcript.text,
            "language": base_language,
            "duration": wtf_doc.transcript.duration,
            "model": wtf_doc.metadata.model,
            "words": [],
            "segments": [],
        }

        # Convert words
        if wtf_doc.words:
            parakeet_data["words"] = [
                {
                    "id": word.id,
                    "start": word.start,
                    "end": word.end,
                    "text": word.text,
                    "confidence": word.confidence,
                }
                for word in wtf_doc.words
            ]

        # Convert segments
        if wtf_doc.segments:
            parakeet_data["segments"] = [
                {
                    "id": segment.id,
                    "start": segment.start,
                    "end": segment.end,
                    "text": segment.text,
                    "confidence": segment.confidence,
                    # Copy so the result does not alias the document's list.
                    "words": list(segment.words or []),
                }
                for segment in wtf_doc.segments
            ]

        # Merge extensions back if available
        if wtf_doc.extensions and "parakeet_raw_response" in wtf_doc.extensions:
            original_raw = wtf_doc.extensions["parakeet_raw_response"]
            parakeet_data.update(original_raw)
            # Ensure our converted data overrides the raw where appropriate
            parakeet_data["text"] = wtf_doc.transcript.text
            parakeet_data["duration"] = wtf_doc.transcript.duration
            parakeet_data["language"] = base_language

        return parakeet_data

    def _calculate_overall_confidence(self, parakeet_data: Dict[str, Any]) -> float:
        """
        Calculate overall confidence from Parakeet data.

        Returns the mean of the per-word ``confidence`` values; a missing or
        ``None`` confidence counts as 0.0, and no words yields 0.0.
        """
        words = parakeet_data.get("words", [])
        if not words:
            return 0.0

        confidences = [float(word.get("confidence") or 0.0) for word in words]
        return float(sum(confidences) / len(confidences))

    def _extract_speakers(self, words_data: List[Dict[str, Any]]) -> Dict[str, WTFSpeaker]:
        """
        Extract speaker information from Parakeet words data.

        Returns an empty mapping when there are no words; otherwise a single
        default speaker whose total time is the summed word durations.
        """
        # Parakeet doesn't do speaker diarization by default
        # Return a single default speaker
        if not words_data:
            return {}

        # Calculate total speaking time
        total_time = sum(word.get("end", 0.0) - word.get("start", 0.0) for word in words_data)

        return {
            "0": WTFSpeaker(
                id="0",
                label="Speaker 1",
                segments=[0],  # Default segment
                total_time=total_time,
                confidence=0.9,  # Default confidence
            )
        }

    def _convert_parakeet_words(self, words_data: List[Dict[str, Any]]) -> List[WTFWord]:
        """
        Convert Parakeet words to WTF words.

        Entries whose text is empty after stripping are skipped.
        """
        words = []

        for word_data in words_data:
            word_text = (word_data.get("text") or "").strip()
            if not word_text:
                continue

            words.append(
                WTFWord(
                    id=word_data.get("id", 0),
                    start=word_data.get("start", 0.0),
                    end=word_data.get("end", 0.0),
                    text=word_text,
                    confidence=normalize_confidence(
                        word_data.get("confidence", 0.0), self.provider_name
                    ),
                    speaker="0",  # Default speaker
                    is_punctuation=self._detect_punctuation(word_text),
                )
            )

        return words

    def _convert_parakeet_segments(
        self, segments_data: List[Dict[str, Any]], wtf_words: List[WTFWord]
    ) -> List[WTFSegment]:
        """
        Convert Parakeet segments to WTF segments.

        ``wtf_words`` is accepted for interface compatibility but is not
        consulted; each segment keeps the word ids listed in its own data.
        """
        return [
            WTFSegment(
                id=segment_data.get("id", 0),
                start=segment_data.get("start", 0.0),
                end=segment_data.get("end", 0.0),
                text=segment_data.get("text", ""),
                confidence=normalize_confidence(
                    segment_data.get("confidence", 0.0), self.provider_name
                ),
                speaker="0",  # Default speaker
                words=segment_data.get("words", []),
            )
            for segment_data in segments_data
        ]

    def _detect_punctuation(self, word_text: str) -> bool:
        """Simple check to see if a word is primarily punctuation."""
        return bool(re.fullmatch(r"^\W+$", word_text))

    def _calculate_quality_metrics(
        self, parakeet_data: Dict[str, Any], wtf_words: List[WTFWord]
    ) -> WTFQuality:
        """
        Calculate quality metrics based on Parakeet data.

        ``audio_quality`` is derived only from the average word confidence:
        above 0.8 is "high", above 0.6 is "medium", anything else is "low".
        """
        low_confidence_words = sum(1 for word in wtf_words if word.confidence < 0.5)
        average_confidence = (
            sum(word.confidence for word in wtf_words) / len(wtf_words) if wtf_words else 0.0
        )

        if average_confidence > 0.8:
            audio_quality = "high"
        elif average_confidence > 0.6:
            audio_quality = "medium"
        else:
            audio_quality = "low"

        return WTFQuality(
            audio_quality=audio_quality,
            background_noise=None,
            multiple_speakers=None,
            overlapping_speech=None,
            silence_ratio=None,
            average_confidence=average_confidence,
            low_confidence_words=low_confidence_words,
            processing_warnings=[],
        )