import { Injectable } from '@angular/core';
import { Subject, Observable, BehaviorSubject } from 'rxjs';
import { filter, first } from 'rxjs/operators';
import { environment } from '../../../../environments/environment.local';
import { RequestType, ResponseType, ResponseStatus } from './sam2.enums';
import { ExecutionContext } from '../../interfaces/chunk/chunk-context.interface';
import { CacheService } from '../cache/cache.service';
import { NavigationEnd, Router } from '@angular/router';
import { ToastrService } from 'ngx-toastr';

@Injectable({
  providedIn: 'root',
})
/**
 * WebSocket client for the SAM2 backend.
 *
 * Owns a single socket connection and a single SAM2 session id. It exposes
 * request helpers (healthcheck, single-image upload, prompt / refine-prompt)
 * that send JSON requests and resolve with the matching JSON response. Binary
 * mask frames that follow a successful prompt response are decoded into
 * data-URL masks while on the `/annotation-tool` route. On other routes they
 * are forwarded raw via {@link receiveBinaryMessages}.
 */
export class Sam2WebsocketService {
  /** Current socket; replaced on every (re)connect. */
  private socket!: WebSocket;
  /** Parsed JSON messages from the server plus locally synthesized mask results. */
  public messages: Subject<any> = new Subject<any>();
  /** Raw binary frames received while not on the annotation-tool route. */
  private binaryMessages: Subject<ArrayBuffer> = new Subject<ArrayBuffer>();
  private WS_URL: string | undefined = environment.sam2WebsocketUrl;
  private currentSessionId: string = '';
  private promptRequests: Map<
    string,
    { resolve: (value?: unknown) => void; reject: (reason?: any) => void }
  > = new Map();
  /** Milliseconds to wait for any single server response. */
  private readonly TIMEOUT = 10000;
  /** Messages sent while the socket was not open; flushed on the next 'open'. */
  private messageQueue: any[] = [];
  private isConnected$ = new BehaviorSubject<boolean>(false);
  private isEnabled: boolean = false;
  private isCreatingSession: boolean = false;
  private sessionCreated$ = new BehaviorSubject<boolean>(false);
  /** In-flight session creation (including retries) shared by concurrent callers. */
  private pendingSession: Promise<void> | null = null;
  private retryCount: number = 0;
  private readonly MAX_RETRIES = 3;
  /** Number of binary mask frames announced by the last successful prompt response. */
  private expectedBinaryMessages: number = 0;
  private currentBinaryMessages: ArrayBuffer[] = [];
  private currentScores: number[] = [];
  private isAnnotationToolRoute: boolean = false;

  constructor(
    private cacheService: CacheService,
    private router: Router,
    private toastr: ToastrService
  ) {
    this.isEnabled = !!this.WS_URL;
    // Seed the route flag from the current URL. If this service is first
    // injected after the initial NavigationEnd, the subscription below would
    // otherwise leave the flag false until the next navigation.
    this.isAnnotationToolRoute = this.router.url === '/annotation-tool';
    this.router.events
      .pipe(filter((event: any) => event instanceof NavigationEnd))
      .subscribe((event: any) => {
        this.isAnnotationToolRoute = event.url === '/annotation-tool';
      });
  }

  /**
   * Opens the socket (if needed) and makes sure a SAM2 session exists.
   * Resolves immediately when no websocket URL is configured.
   *
   * @throws whatever connection or session creation rejected with (after logging it).
   */
  public init(): Promise<void> {
    if (!this.isEnabled) {
      console.log('SAM2 WebSocket service is disabled - no URL configured');
      return Promise.resolve();
    }
    return this.connectToWebsocket()
      .then(() => this.ensureSession())
      .catch((error) => {
        console.error('Failed to initialize WebSocket service:', error);
        throw error;
      });
  }

  /** Whether a websocket URL was configured for this environment. */
  public isServiceEnabled(): boolean {
    return this.isEnabled;
  }

  /**
   * Resolves once a session has been created.
   *
   * Concurrent callers share one creation attempt (with its retries) instead of
   * returning early while creation is still running. Otherwise they could send
   * session-scoped requests before the server has confirmed the session.
   */
  private async ensureSession(): Promise<void> {
    if (this.sessionCreated$.value) {
      return;
    }

    if (!this.pendingSession) {
      this.retryCount = 0;
      this.pendingSession = this.createSessionWithRetry().finally(() => {
        this.pendingSession = null;
      });
    }
    return this.pendingSession;
  }

  /**
   * Calls {@link createSession}, retrying up to MAX_RETRIES attempts in total
   * with a linearly growing delay (1s, 2s, ...) between attempts.
   *
   * @throws the last creation error once all attempts are exhausted.
   */
  private async createSessionWithRetry(): Promise<void> {
    try {
      await this.createSession();
    } catch (error) {
      this.retryCount++;
      if (this.retryCount < this.MAX_RETRIES) {
        console.log(
          `Retrying session creation (attempt ${this.retryCount + 1}/${this.MAX_RETRIES})`
        );
        await new Promise((resolve) =>
          setTimeout(resolve, 1000 * this.retryCount)
        );
        return this.createSessionWithRetry();
      }
      throw error;
    }
  }

  /**
   * Generates a new session id and sends a CreateSession request.
   * If a creation is already running, waits for that one instead.
   */
  private async createSession(): Promise<void> {
    if (!this.isEnabled) {
      return;
    }

    if (this.isCreatingSession) {
      return this.waitForSessionCreation();
    }

    try {
      this.isCreatingSession = true;
      this.currentSessionId = crypto.randomUUID();

      const message = {
        session_id: this.currentSessionId,
        request_type: RequestType.CreateSession,
      };

      // Awaited (rather than returned) so an asynchronous failure also reaches
      // the catch below. resetSessionState() is safe to run twice.
      await this.handleSessionCreation(message);
    } catch (error) {
      this.resetSessionState();
      throw error;
    }
  }

  /**
   * Polls every 100 ms until the in-progress creation finishes.
   * Resolves if it produced a session, rejects otherwise.
   */
  private waitForSessionCreation(): Promise<void> {
    return new Promise((resolve, reject) => {
      const checkInterval = setInterval(() => {
        if (!this.isCreatingSession) {
          clearInterval(checkInterval);
          if (this.sessionCreated$.value) {
            resolve();
          } else {
            reject(new Error('Session creation failed'));
          }
        }
      }, 100);
    });
  }

  /**
   * Sends the CreateSession message and waits for a SessionCreated reply,
   * any Error-status message, or TIMEOUT, whichever comes first.
   */
  private handleSessionCreation(message: any): Promise<void> {
    return new Promise((resolve, reject) => {
      const timeout = setTimeout(() => {
        // Stop listening. Otherwise a late SessionCreated reply would still
        // flip sessionCreated$ after this promise has already rejected.
        subscription.unsubscribe();
        this.resetSessionState();
        reject(new Error('Session creation timeout'));
      }, this.TIMEOUT);

      const subscription = this.messages.subscribe({
        next: (data) => {
          if (data.response_type === ResponseType.SessionCreated) {
            this.handleSessionResponse(
              data,
              timeout,
              subscription,
              resolve,
              reject
            );
          } else if (data.response_status === ResponseStatus.Error) {
            this.handleSessionError(data, timeout, subscription, reject);
          }
        },
        error: (error) => {
          this.handleSessionError(error, timeout, subscription, reject);
        },
      });

      this.sendMessage(message);
    });
  }

  /**
   * Completes session creation from a SessionCreated reply.
   * An Ok status marks the session as created; anything else is treated as an error.
   */
  private handleSessionResponse(
    data: any,
    timeout: any,
    subscription: { unsubscribe(): void },
    resolve: () => void,
    reject: (reason?: any) => void
  ): void {
    if (data.response_status === ResponseStatus.Ok) {
      clearTimeout(timeout);
      subscription.unsubscribe();
      this.isCreatingSession = false;
      this.sessionCreated$.next(true);
      console.log('Session created successfully:', this.currentSessionId);
      resolve();
    } else {
      this.handleSessionError(
        { message: data.message || 'Unknown error during session creation' },
        timeout,
        subscription,
        reject
      );
    }
  }

  /** Tears down a failed creation attempt and rejects with an Error. */
  private handleSessionError(
    error: any,
    timeout: any,
    subscription: { unsubscribe(): void },
    reject: (reason?: any) => void
  ): void {
    clearTimeout(timeout);
    subscription.unsubscribe();
    this.resetSessionState();
    const errorMessage = error.message || 'Failed to create session';
    console.error('Session creation error:', errorMessage);
    reject(new Error(errorMessage));
  }

  /** Clears the in-progress flag and the (unconfirmed) session id. */
  private resetSessionState(): void {
    this.isCreatingSession = false;
    this.currentSessionId = '';
  }

  /**
   * Opens a new socket unless one is already open.
   *
   * Resolves on 'open'. Rejects if the socket errors or closes before opening,
   * so callers such as {@link init} do not hang on an unreachable server.
   * A socket that was previously connecting is closed and replaced.
   */
  private connectToWebsocket(): Promise<void> {
    if (!this.isEnabled) {
      return Promise.resolve();
    }

    if (this.socket?.readyState === WebSocket.OPEN) {
      return Promise.resolve();
    }

    if (this.socket) {
      this.socket.close();
    }
    if (this.isConnected$.value) {
      this.isConnected$.next(false);
    }

    return new Promise((resolve, reject) => {
      try {
        const socket = new WebSocket(this.WS_URL!);
        this.socket = socket;
        socket.binaryType = 'arraybuffer';
        let opened = false;

        socket.addEventListener('open', (event) => {
          opened = true;
          console.log('WebSocket connection opened:', event);
          this.isConnected$.next(true);
          this.sendQueuedMessages();
          resolve();
        });

        socket.addEventListener('message', (event) => {
          if (event.data instanceof ArrayBuffer) {
            this.handleBinaryMessage(event.data);
          } else {
            this.handleMessage(event);
          }
        });

        socket.addEventListener('error', (event) => {
          console.error('WebSocket error:', event);
          if (!opened) {
            reject(new Error('WebSocket connection failed'));
          }
        });

        socket.addEventListener('close', (event) => {
          console.log('WebSocket connection closed:', event);
          if (!opened) {
            reject(new Error('WebSocket connection closed before opening'));
          }
          // Ignore close events from sockets that have already been replaced,
          // so they don't mark the newer connection as disconnected.
          if (this.socket === socket) {
            this.isConnected$.next(false);
          }
        });
      } catch (error) {
        reject(error);
      }
    });
  }

  /**
   * Sends a JSON message, stamping it with the current session id if it has
   * none (the object is modified in place). Queues the message when the socket
   * is not open.
   */
  public sendMessage(message: any): void {
    if (!this.isEnabled) {
      console.warn('SAM2 WebSocket service is disabled');
      return;
    }

    // Also check the socket itself: isConnected$ can lag behind a socket that
    // is not (or no longer) open. send() throws while CONNECTING and silently
    // drops data once CLOSING/CLOSED.
    if (this.isConnected$.value && this.socket?.readyState === WebSocket.OPEN) {
      this.socket.send(JSON.stringify(this.withSessionId(message)));
    } else {
      console.warn('WebSocket is not open. Queueing message.');
      this.messageQueue.push(message);
    }
  }

  /** Sets `session_id` on the message (in place) when missing and a session id is known. */
  private withSessionId(message: any): any {
    if (this.currentSessionId && !message.session_id) {
      message.session_id = this.currentSessionId;
    }
    return message;
  }

  /** Flushes messages queued while disconnected, in FIFO order. */
  private sendQueuedMessages(): void {
    while (this.messageQueue.length > 0) {
      const message = this.messageQueue.shift();
      this.socket.send(JSON.stringify(this.withSessionId(message)));
    }
  }

  /** Stream of parsed JSON messages (and synthesized prompt-mask results). */
  public receiveMessage(): Observable<any> {
    return this.messages.asObservable();
  }

  /**
   * Sends a HealthCheck request once connected and resolves with the Ok response.
   *
   * @param args unused; kept for the chunk-execution calling convention.
   * @throws if the service is disabled, the response is not Ok, or it times out.
   */
  public async healthcheck(
    args: any[],
    chunkContext: ExecutionContext
  ): Promise<any> {
    if (!this.isEnabled) {
      const error = new Error('SAM2 WebSocket service is disabled');
      chunkContext.addLog('SAM2 WebSocket service is disabled', 'warning');
      throw error;
    }

    chunkContext.addLog('Initiating healthcheck...', 'info');
    try {
      const message = {
        session_id: this.currentSessionId,
        request_type: RequestType.HealthCheck,
      };
      await this.sendMessageWhenConnected(message);
      chunkContext.addLog('Healthcheck message sent', 'info');

      return this.waitForResponse(
        ResponseType.HealthCheck,
        chunkContext,
        'Healthcheck'
      );
    } catch (error: any) {
      chunkContext.addLog(
        `Error during healthcheck: ${error.message || error}`,
        'error'
      );
      throw error; // Rethrow to allow try/catch blocks to work
    }
  }

  /**
   * Uploads one image. It ensures a session, announces the upload with a
   * prepare message, then sends the decoded bytes as a binary frame once the
   * server reports SingleImageReady.
   *
   * @param args `args[0]` is the image as a base64 string (optionally a data URL).
   * @returns the SingleImageReceived response.
   * @throws on invalid arguments, disabled service, server error or timeout.
   */
  public async sendSingleImage(
    args: any[],
    chunkContext?: ExecutionContext
  ): Promise<any> {
    if (!this.isEnabled) {
      const error = new Error('SAM2 WebSocket service is disabled');
      if (chunkContext) {
        chunkContext.addLog('SAM2 WebSocket service is disabled', 'warning');
      }
      throw error;
    }

    if (!this.validateImageArgs(args, chunkContext)) {
      const error = new Error('Invalid arguments for sendSingleImage');
      if (chunkContext) {
        chunkContext.addLog('Invalid arguments for sendSingleImage', 'error');
      }
      throw error;
    }

    const [imageDataString] = args;

    try {
      await this.ensureSession();
      await this.sendPrepareMessage(RequestType.SendSingleImage, chunkContext);

      return await this.handleImageUpload(imageDataString, chunkContext);
    } catch (error: any) {
      if (chunkContext) {
        chunkContext.addLog(
          `Error sending single image: ${error.message || error}`,
          'error'
        );
      }
      throw error; // Rethrow to allow try/catch blocks to work
    }
  }

  /**
   * True when the service is enabled and has a session id.
   * NOTE(review): not referenced within this class.
   */
  private validateServiceState(): boolean {
    if (!this.isEnabled || !this.currentSessionId) {
      console.warn(
        'SAM2 WebSocket service is disabled or no session available'
      );
      return false;
    }
    return true;
  }

  /** Checks that `args[0]` is a non-blank string, logging the reason when not. */
  private validateImageArgs(
    args: any[],
    chunkContext?: ExecutionContext
  ): boolean {
    if (!Array.isArray(args) || args.length < 1) {
      if (chunkContext) {
        chunkContext.addLog(
          'Missing image data: at least one argument with image data is required.',
          'error'
        );
      }
      return false;
    }

    if (typeof args[0] !== 'string' || !args[0].trim()) {
      if (chunkContext) {
        chunkContext.addLog(
          'Invalid image data: image must be provided as a non-empty base64 string.',
          'error'
        );
      }
      return false;
    }

    return true;
  }

  /** Sends the JSON message that announces an upcoming request of `requestType`. */
  private async sendPrepareMessage(
    requestType: RequestType,
    chunkContext?: ExecutionContext
  ): Promise<void> {
    const prepareMessage = {
      session_id: this.currentSessionId,
      request_type: requestType,
      session: this.currentSessionId,
    };

    this.sendMessage(prepareMessage);
    chunkContext?.addLog(
      `${RequestType[requestType]} prepare message sent`,
      'info'
    );
  }

  /**
   * Waits for SingleImageReady, then sends the image bytes. It resolves on an Ok
   * SingleImageReceived, and rejects on any Error-status message or after TIMEOUT.
   */
  private async handleImageUpload(
    imageDataString: string,
    chunkContext?: ExecutionContext
  ): Promise<any> {
    return new Promise<any>((resolve, reject) => {
      const timeout = setTimeout(() => {
        subscription.unsubscribe();
        const error = new Error('Timeout waiting for image upload response');
        if (chunkContext) {
          chunkContext.addLog('Timeout waiting for image upload response', 'error');
        }
        reject(error);
      }, this.TIMEOUT);

      const subscription = this.receiveMessage().subscribe({
        next: (data) => {
          if (data.response_type === ResponseType.SingleImageReady) {
            try {
              // Accept both bare base64 and data URLs ("data:...;base64,<payload>").
              const binaryData = this.base64ToBinary(
                imageDataString.includes('base64,')
                  ? imageDataString.split('base64,')[1]
                  : imageDataString
              );
              this.socket.send(binaryData);
              if (chunkContext) {
                chunkContext.addLog('Image data sent', 'info');
              }
            } catch (error: any) {
              this.cleanupSubscription(timeout, subscription);
              if (chunkContext) {
                chunkContext.addLog(`Error processing image data: ${error.message || error}`, 'error');
              }
              reject(error);
            }
          } else if (
            data.response_type === ResponseType.SingleImageReceived &&
            data.response_status === ResponseStatus.Ok
          ) {
            this.cleanupSubscription(timeout, subscription);
            if (chunkContext) {
              chunkContext.addLog('Image received successfully', 'info');
            }
            resolve(data);
          } else if (data.response_status === ResponseStatus.Error) {
            const error = new Error(data.message || 'Error processing image');
            this.handleUploadError(error, timeout, subscription, reject, chunkContext);
          }
        },
        error: (error) => {
          this.handleUploadError(error, timeout, subscription, reject, chunkContext);
        },
      });
    });
  }

  /** Cancels a pending timeout and its associated subscription. */
  private cleanupSubscription(
    timeout: any,
    subscription: { unsubscribe(): void }
  ): void {
    clearTimeout(timeout);
    subscription.unsubscribe();
  }

  /** Tears down an upload wait and rejects with an Error carrying the message. */
  private handleUploadError(
    error: any,
    timeout: any,
    subscription: { unsubscribe(): void },
    reject: (reason?: any) => void,
    chunkContext?: ExecutionContext
  ): void {
    this.cleanupSubscription(timeout, subscription);
    const errorMessage = error.message || 'Unknown error during image upload';
    if (chunkContext) {
      chunkContext.addLog(`Error during image upload: ${errorMessage}`, 'error');
    }
    reject(new Error(errorMessage));
  }

  /** Decodes a bare base64 string into its bytes. Throws if `atob` rejects the input. */
  private base64ToBinary(base64: string): ArrayBuffer {
    const binaryString = window.atob(base64);
    const bytes = new Uint8Array(binaryString.length);
    for (let i = 0; i < binaryString.length; i++) {
      bytes[i] = binaryString.charCodeAt(i);
    }
    return bytes.buffer;
  }

  /** Sends `args[0]` as a new prompt and resolves with the prompt response. */
  public async sendPrompt(
    args: any[],
    chunkContext: ExecutionContext
  ): Promise<any> {
    return this.handlePromptRequest(RequestType.SendPrompt, args, chunkContext);
  }

  /** Sends `args[0]` as a refinement prompt and resolves with the prompt response. */
  public async refinePrompt(
    args: any[],
    chunkContext: ExecutionContext
  ): Promise<any> {
    return this.handlePromptRequest(
      RequestType.RefinePrompt,
      args,
      chunkContext
    );
  }

  /**
   * Validates and sends a prompt-type request, then waits for the ReceivePrompt
   * result. On the annotation-tool route that result carries the decoded masks.
   *
   * @throws on disabled service, invalid arguments, non-Ok response or timeout.
   */
  private async handlePromptRequest(
    requestType: RequestType,
    args: any[],
    chunkContext: ExecutionContext
  ): Promise<any> {
    if (!this.isEnabled) {
      const error = new Error('SAM2 WebSocket service is disabled');
      chunkContext.addLog('SAM2 WebSocket service is disabled', 'warning');
      throw error;
    }

    if (!this.validatePromptArgs(args, chunkContext)) {
      const error = new Error(`Invalid arguments for ${RequestType[requestType]}`);
      chunkContext.addLog(`Invalid arguments for ${RequestType[requestType]}`, 'error');
      throw error;
    }

    try {
      const message = {
        session_id: this.currentSessionId,
        request_type: requestType,
        prompt: args[0],
      };

      this.sendMessage(message);
      chunkContext.addLog(`${RequestType[requestType]} request sent`, 'info');

      return this.handlePromptResponse(chunkContext);
    } catch (error: any) {
      chunkContext.addLog(
        `Error with ${RequestType[requestType]}: ${error.message || error}`,
        'error'
      );
      throw error; // Rethrow to allow try/catch blocks to work
    }
  }

  /** Checks that `args[0]` is a non-blank string, logging the reason when not. */
  private validatePromptArgs(
    args: any[],
    chunkContext: ExecutionContext
  ): boolean {
    if (!Array.isArray(args) || args.length < 1) {
      chunkContext.addLog(
        'Missing prompt text: prompt requires at least one argument.',
        'error'
      );
      return false;
    }

    if (typeof args[0] !== 'string' || !args[0].trim()) {
      chunkContext.addLog(
        'Invalid prompt: prompt must be a non-empty string.',
        'error'
      );
      return false;
    }

    return true;
  }

  /** Waits for the next ReceivePrompt message on {@link messages}. */
  public async handlePromptResponse(
    chunkContext: ExecutionContext
  ): Promise<any> {
    return this.waitForResponse(
      ResponseType.ReceivePrompt,
      chunkContext,
      'prompt'
    );
  }

  /**
   * Resolves with the first message of `responseType` when its status is Ok.
   * Rejects when that message's status is not Ok, or after TIMEOUT.
   */
  private async waitForResponse(
    responseType: ResponseType,
    chunkContext: ExecutionContext,
    operationType: string
  ): Promise<any> {
    return new Promise((resolve, reject) => {
      const timeout = setTimeout(() => {
        // Drop the listener so a late response does not linger as a leaked subscription.
        subscription.unsubscribe();
        chunkContext.addLog(
          `Timeout waiting for ${operationType.toLowerCase()} response`,
          'error'
        );
        reject(new Error(`Timeout waiting for ${operationType.toLowerCase()} response`));
      }, this.TIMEOUT);

      const subscription = this.messages
        .pipe(
          filter((data) => data.response_type === responseType),
          first()
        )
        .subscribe({
          next: (data) => {
            clearTimeout(timeout);
            subscription.unsubscribe();
            if (data.response_status === ResponseStatus.Ok) {
              chunkContext.addLog(
                `${operationType} response received successfully`,
                'info'
              );
              resolve(data);
            } else {
              const error = new Error(data.message || `${operationType} failed`);
              chunkContext.addLog(
                `${operationType} failed: ${data.message || 'Unknown error'}`,
                'error'
              );
              reject(error);
            }
          },
          error: (error) => {
            clearTimeout(timeout);
            chunkContext.addLog(
              `Error waiting for ${operationType.toLowerCase()} response: ${error.message || error}`,
              'error'
            );
            reject(error);
          },
        });
    });
  }

  /**
   * Parses a text frame as JSON, toasts any error it carries, and routes it.
   * Successful prompt responses start mask-frame collection; everything else
   * is emitted on {@link messages}.
   */
  private handleMessage(event: MessageEvent): void {
    if (event.data instanceof ArrayBuffer) {
      this.handleBinaryMessage(event.data);
      return;
    }

    try {
      const data = JSON.parse(event.data);
      console.log('Received WebSocket message:', data);

      if (data.response_status === ResponseStatus.Error) {
        this.handleErrorMessage(data);
      }

      if (this.isPromptResponseMessage(data)) {
        this.handlePromptResponseMessage(data);
      } else {
        this.messages.next(data);
      }
    } catch (error) {
      console.error('Error parsing WebSocket message:', error);
      this.toastr.error('Error parsing WebSocket message', 'Error');
    }
  }

  /** Shows one toast per message extracted from an Error-status payload. */
  private handleErrorMessage(error: any): void {
    const errorMessages: string[] = this.parseErrorMessages(error);

    if (errorMessages.length > 0) {
      errorMessages.forEach((message) => {
        this.toastr.error(message, 'Error');
      });
    }
  }

  /**
   * Collects human-readable messages from an error payload, in this order:
   * `message`, then for an object `error` the first item of each array field
   * and each non-empty string field, and finally `error` itself when it is a
   * non-empty string.
   */
  private parseErrorMessages(error: any): string[] {
    const errorMessages: string[] = [];

    // Handle direct message property
    if (error.message) {
      errorMessages.push(error.message);
    }

    // Handle nested error object if exists
    if (error.error && typeof error.error === 'object') {
      const errorObj = error.error;

      for (const key in errorObj) {
        if (Object.prototype.hasOwnProperty.call(errorObj, key)) {
          const element = errorObj[key];
          if (Array.isArray(element) && element.length > 0) {
            errorMessages.push(element[0]);
          }
          if (typeof element === 'string' && element !== '') {
            errorMessages.push(element);
          }
        }
      }
    }

    // Handle string error
    if (typeof error.error === 'string' && error.error !== '') {
      errorMessages.push(error.error);
    }

    return errorMessages;
  }

  private isPromptResponseMessage(data: any): boolean {
    return data.response_type === ResponseType.ReceivePrompt;
  }

  /**
   * An Ok prompt response is not forwarded; it announces the mask frames that
   * follow. Non-Ok prompt responses are forwarded so waiters can reject.
   */
  private handlePromptResponseMessage(data: any): void {
    if (data.response_status === ResponseStatus.Ok) {
      this.initializeBinaryMessageProcessing(data);
    } else {
      this.messages.next(data);
    }
  }

  /**
   * Starts collecting mask frames: one frame is expected per entry in
   * `data.scores`.
   */
  private initializeBinaryMessageProcessing(data: any): void {
    this.expectedBinaryMessages = data.scores?.length || 0;
    this.currentScores = data.scores || [];
    this.currentBinaryMessages = [];
    console.log(
      `Expecting ${this.expectedBinaryMessages} binary messages based on scores array`
    );

    // With no frames to wait for, the completion check (which only runs when a
    // frame arrives) would never fire and the prompt waiter would time out.
    // Emit the (empty) result right away instead.
    if (this.expectedBinaryMessages === 0 && this.isAnnotationToolRoute) {
      this.emitProcessedMasks([]);
    }
  }

  /** Raw binary frames received outside the annotation-tool route. */
  public receiveBinaryMessages(): Observable<ArrayBuffer> {
    return this.binaryMessages.asObservable();
  }

  private handleBinaryMessage(data: ArrayBuffer): void {
    if (this.isAnnotationToolRoute) {
      this.handleAnnotationToolBinaryMessage(data);
    } else {
      this.binaryMessages.next(data);
    }
  }

  /**
   * Buffers a mask frame. Once the announced count is reached, decodes the
   * batch and emits a synthesized ReceivePrompt result carrying the masks.
   */
  private handleAnnotationToolBinaryMessage(data: ArrayBuffer): void {
    this.currentBinaryMessages.push(data);

    if (this.currentBinaryMessages.length !== this.expectedBinaryMessages) {
      return;
    }

    // Take the batch and reset synchronously. Decoding is asynchronous, and a
    // following prompt's response/frames must not be mixed into, or wiped by,
    // this batch.
    const frames = this.currentBinaryMessages;
    const scores = this.currentScores;
    this.resetBinaryMessageState();

    this.processBinaryMessages(frames, scores)
      .then((masks) => this.emitProcessedMasks(masks))
      .catch((error) => {
        console.error('Error processing SAM2 mask data:', error);
      });
  }

  /** Decodes every frame into a mask, pairing it with the score at the same index. */
  private async processBinaryMessages(
    binaryMessages: ArrayBuffer[],
    scores: number[]
  ): Promise<Array<{ imageData: string; score: number }>> {
    return Promise.all(
      binaryMessages.map((binaryData, index) =>
        this.convertBinaryToMask(binaryData, scores[index])
      )
    );
  }

  /** Wraps PNG bytes in a data URL via FileReader; a missing score becomes 0. */
  private convertBinaryToMask(
    binaryData: ArrayBuffer,
    score: number
  ): Promise<{ imageData: string; score: number }> {
    return new Promise((resolve) => {
      const blob = new Blob([binaryData], { type: 'image/png' });
      const reader = new FileReader();

      reader.onloadend = () => {
        resolve({
          imageData: reader.result as string,
          score: score || 0,
        });
      };

      reader.readAsDataURL(blob);
    });
  }

  /** Publishes decoded masks as an Ok ReceivePrompt message on {@link messages}. */
  private emitProcessedMasks(
    masks: Array<{ imageData: string; score: number }>
  ): void {
    this.messages.next({
      response_type: ResponseType.ReceivePrompt,
      response_status: ResponseStatus.Ok,
      masks: masks,
    });
  }

  private resetBinaryMessageState(): void {
    this.expectedBinaryMessages = 0;
    this.currentBinaryMessages = [];
    this.currentScores = [];
  }

  /** Sends the message now if connected, otherwise after isConnected$ next becomes true. */
  private async sendMessageWhenConnected(message: any): Promise<void> {
    if (this.isConnected$.value) {
      this.sendMessage(message);
    } else {
      await this.isConnected$
        .pipe(
          filter((isConnected) => isConnected),
          first()
        )
        .toPromise();
      this.sendMessage(message);
    }
  }
}
