import os
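# Disable librosa caching and numba JIT compilation, and point numba's cache
# at a writable directory (typically done to avoid cache/permission issues
# when running under a web server).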
os.environ['LIBROSA_CACHE_LEVEL'] = '0'
os.environ["NUMBA_DISABLE_JIT"] = "1"
os.makedirs("/var/www/html/eduruby.in/lip-sync/numba_cache", exist_ok=True)
os.environ["NUMBA_CACHE_DIR"] = "/var/www/html/eduruby.in/lip-sync/numba_cache"
import cv2
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
import subprocess
import json
import tempfile

# --- Core imports
from wav2lip_core.models import Wav2Lip
from wav2lip_core.audio import melspectrogram
from wav2lip_core.face_detection.detection.sfd.sfd_detector import SFDDetector

# -------------------------------------------------------------
#  AUDIO PROCESSING WITHOUT LIBROSA
# -------------------------------------------------------------
def get_audio_duration(audio_path):
    """Get audio duration using ffprobe"""
    try:
        cmd = [
            'ffprobe', '-v', 'quiet', '-print_format', 'json',
            '-show_format', '-show_streams', audio_path
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        info = json.loads(result.stdout)
        
        if 'format' in info and 'duration' in info['format']:
            duration = float(info['format']['duration'])
            print(f"[INFO] Audio duration from ffprobe: {duration:.2f}s")
            return duration
        
        for stream in info.get('streams', []):
            if stream.get('codec_type') == 'audio' and 'duration' in stream:
                duration = float(stream['duration'])
                print(f"[INFO] Audio duration from stream: {duration:.2f}s")
                return duration
                
        raise Exception("No duration found in audio file")
        
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] ffprobe failed: {e}")
        raise Exception(f"Could not get audio duration: {e}")
    except Exception as e:
        print(f"[ERROR] Error processing audio duration: {e}")
        raise

def load_audio_ffmpeg(audio_path, target_sr=16000):
    """Load audio using ffmpeg instead of librosa"""
    try:
        # Create a temporary WAV file for the converted audio
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_path = temp_file.name

        try:
            # Convert to mono WAV at the target sample rate using ffmpeg
            cmd = [
                'ffmpeg', '-y', '-i', audio_path,
                '-ac', '1', '-ar', str(target_sr),
                '-acodec', 'pcm_s16le', '-f', 'wav',
                temp_path
            ]
            subprocess.run(cmd, check=True, capture_output=True)

            # Read the 16-bit PCM samples and scale them to [-1, 1]
            import wave

            with wave.open(temp_path, 'rb') as wav_file:
                frames = wav_file.readframes(wav_file.getnframes())
                audio_array = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
        finally:
            # Remove the temporary file even if conversion or reading fails
            if os.path.exists(temp_path):
                os.unlink(temp_path)

        return audio_array, target_sr
        
    except Exception as e:
        print(f"[ERROR] Failed to load audio with ffmpeg: {e}")
        # Fallback: try using soundfile if available
        try:
            import soundfile as sf
            audio, sr = sf.read(audio_path)
            if audio.ndim > 1:
                audio = audio.mean(axis=1)  # Convert to mono
            if sr != target_sr:
                # Resample to the target rate
                from scipy import signal
                audio = signal.resample(audio, int(len(audio) * target_sr / sr))
            return audio.astype(np.float32), target_sr
        except ImportError:
            raise Exception("Need soundfile or working ffmpeg for audio processing")

# -------------------------------------------------------------
#  DEVICE SETUP
# -------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INIT] Using device: {device.upper()}")

# -------------------------------------------------------------
#  LOAD WAV2LIP MODEL
# -------------------------------------------------------------
def load_model(path):
    print(f"[INFO] Loading Wav2Lip model from {path} ...")
    model = Wav2Lip()
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    state_dict = checkpoint["state_dict"]
    # Strip the "module." prefix left by DataParallel training so the keys
    # match a plain Wav2Lip instance
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)
    model = model.to(device).eval()
    print("[INFO] Model loaded successfully!")
    return model

# -------------------------------------------------------------
#  AUDIO → MEL CHUNKS (WITHOUT LIBROSA)
# -------------------------------------------------------------
def prepare_mel_chunks(audio_path, fps):
    # Use ffmpeg instead of librosa for audio loading
    wav, sr = load_audio_ffmpeg(audio_path, target_sr=16000)
    
    # Use the existing melspectrogram function from wav2lip_core
    mel = melspectrogram(wav)
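    # With Wav2Lip's standard audio settings (16 kHz sample rate, hop size 200)
    # the mel spectrogram has 80 frames per second, so 80 / fps mel frames
    # correspond to one video frame; each chunk of 16 mel frames (~0.2 s)
    # conditions one output frame. The trailing partial window is dropped.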
    mel_step_size = 16
    mel_idx_multiplier = 80.0 / fps
    mel_chunks = []

    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > mel.shape[1]:
            break
        mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
        i += 1

    return mel_chunks

# -------------------------------------------------------------
#  LIP-SYNC GENERATION FUNCTION
# -------------------------------------------------------------
def generate_lip_sync(face_path, audio_path, output_path):
    """Generate lip-synced video using a face video and a separate input audio."""
    try:
        model_path = "wav2lip_core/checkpoints/wav2lip.pth"
        model = load_model(model_path)

        print("[INFO] Loading SFD face detector...")
        face_detector = SFDDetector(device=device)

        # --- Step 1: Remove original audio from video (mute)
        muted_video = face_path.replace(".mp4", "_muted.mp4")
        subprocess.run(["ffmpeg", "-y", "-i", face_path, "-an", "-vcodec", "copy", muted_video],
                       check=True, capture_output=True)
        face_path = muted_video

        # --- Step 2: Load frames from muted video
        cap = cv2.VideoCapture(face_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 25
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()
        print(f"[INFO] Video loaded: {len(frames)} frames @ {fps} fps")

        # --- Step 3: Compute durations (using FFmpeg instead of librosa)
        audio_duration = get_audio_duration(audio_path)
        video_duration = len(frames) / fps

        # --- Step 4: Loop video to match audio length
        if audio_duration > video_duration:
            loop_count = int(np.ceil(audio_duration / video_duration))
            frames = (frames * loop_count)[: int(audio_duration * fps)]
            print(f"[INFO] Looping video {loop_count}x to match {audio_duration:.2f}s audio duration.")

        # --- Step 5: Prepare mel chunks (using FFmpeg instead of librosa)
        mel_chunks = prepare_mel_chunks(audio_path, fps)
        print(f"[INFO] Mel chunks: {len(mel_chunks)}")

        min_len = min(len(frames), len(mel_chunks))
        frames = frames[:min_len]
        mel_chunks = mel_chunks[:min_len]

        # --- Step 6: Detect faces frame-by-frame
        print("[INFO] Detecting faces...")
        batch_faces = []
        for frame in tqdm(frames, desc="Detecting"):
            bboxes = face_detector.detect_from_image(frame)
            if len(bboxes) == 0:
                batch_faces.append(None)
                continue
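            # Use the first detected face; frames with multiple faces are not
            # disambiguated here.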
            # Clamp the box to the frame bounds before cropping
            x1, y1, x2, y2 = map(int, bboxes[0][:4])
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(frame.shape[1], x2), min(frame.shape[0], y2)
            face = frame[y1:y2, x1:x2]
            batch_faces.append((face, (x1, y1, x2, y2)))

        # --- Step 7: Lip-sync inference
        print("[INFO] Performing lip-sync inference...")
        out_frames = []
        for i, face_data in enumerate(tqdm(batch_faces, desc="Lip-sync")):
            if face_data is None:
                out_frames.append(frames[i])
                continue
            face, (x1, y1, x2, y2) = face_data
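            # Wav2Lip expects a 96x96 face crop; the lower half of one copy is
            # masked out and concatenated channel-wise with the full reference
            # crop to form the 6-channel generator input.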
            face = cv2.resize(face, (96, 96)) / 255.0
            masked_face = face.copy()
            masked_face[48:, :] = 0

            face_input = np.concatenate([masked_face, face], axis=2)
            face_tensor = torch.FloatTensor(face_input).permute(2, 0, 1).unsqueeze(0).to(device)
            mel_tensor = torch.FloatTensor(mel_chunks[i]).unsqueeze(0).unsqueeze(0).to(device)

            with torch.no_grad():
                pred = model(mel_tensor, face_tensor)

            # Model output is in [0, 1]; rescale to 8-bit and resize back to
            # the original face box
            pred = (pred.squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
            pred_resized = cv2.resize(pred, (x2 - x1, y2 - y1))
            frame = frames[i].copy()
            frame[y1:y2, x1:x2] = pred_resized
            out_frames.append(frame)

        # --- Step 8: Save temporary silent video
        print("[INFO] Saving generated video (without audio)...")
        Path(os.path.dirname(output_path)).mkdir(parents=True, exist_ok=True)
        temp_video_path = output_path.replace(".mp4", "_no_audio.mp4")

        h, w = out_frames[0].shape[:2]
        out = cv2.VideoWriter(temp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
        for f in out_frames:
            out.write(f)
        out.release()

        # --- Step 9: Add input audio and finalize
        print("[INFO] Adding uploaded audio to the video...")
        os.system(f"ffmpeg -y -i {temp_video_path} -i {audio_path} -c:v copy -c:a aac -shortest {output_path}")

        # Clean up temporary files
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
        if os.path.exists(muted_video):
            os.remove(muted_video)

        print(f"[DONE] Final lip-sync video generated: {output_path}")

    except Exception as e:
        print(f"[ERROR] Lip-sync generation failed: {e}")
        raise e
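
# -------------------------------------------------------------
#  EXAMPLE USAGE (illustrative sketch only; paths are placeholders)
# -------------------------------------------------------------
if __name__ == "__main__":
    # Minimal example of how generate_lip_sync is expected to be called.
    # Adjust the paths to match your deployment layout.
    generate_lip_sync(
        face_path="inputs/face.mp4",       # video containing the speaker's face
        audio_path="inputs/speech.wav",    # audio track to lip-sync to
        output_path="outputs/result.mp4",  # final video with the new audio
    )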