from flask import Blueprint, render_template, request, jsonify, flash, redirect, url_for, make_response
from flask_login import login_required, current_user
from src.models.user import db, Empresa, Resposta, Pergunta
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from openpyxl import Workbook
from openpyxl.styles import Font, Alignment, PatternFill
from openpyxl.utils import get_column_letter

visualizacao_bp = Blueprint('visualizacao', __name__)

@visualizacao_bp.route('/visualizacao/empresa/<int:empresa_id>')
@login_required
def visualizar_empresa(empresa_id):
    # Verificar se o usuário tem permissão para visualizar esta empresa
    if current_user.tipo_perfil == 'cliente' and current_user.empresa_id != empresa_id:
        return render_template('errors/403.html'), 403
    
    # Obter empresa
    empresa = Empresa.query.get_or_404(empresa_id)
    
    # Obter respostas da empresa
    respostas = Resposta.query.filter_by(empresa_id=empresa_id).all()
    
    # Se não houver respostas, redirecionar para página apropriada
    if not respostas:
        flash('Não há dados disponíveis para visualização.', 'warning')
        if current_user.tipo_perfil == 'consultor':
            return redirect(url_for('consultor_pesquisa.listar_pesquisas'))
        else:
            return redirect(url_for('cliente.dashboard'))
    
    # Preparar dados para visualização
    dados_visualizacao = preparar_dados_visualizacao(empresa_id)
    
    # Gerar gráficos
    graficos = gerar_graficos(dados_visualizacao)
    
    return render_template('visualizacao/dashboard.html', 
                          empresa=empresa, 
                          dados=dados_visualizacao,
                          graficos=graficos)

@visualizacao_bp.route('/visualizacao/empresa/<int:empresa_id>/comparativo')
@login_required
def visualizar_comparativo(empresa_id):
    # Verificar se o usuário tem permissão para visualizar esta empresa
    if current_user.tipo_perfil == 'cliente' and current_user.empresa_id != empresa_id:
        return render_template('errors/403.html'), 403
    
    # Obter empresa
    empresa = Empresa.query.get_or_404(empresa_id)
    
    # Obter dados comparativos
    dados_comparativos = preparar_dados_comparativos(empresa_id)
    
    # Gerar gráficos comparativos
    graficos_comparativos = gerar_graficos_comparativos(dados_comparativos, empresa)
    
    return render_template('visualizacao/comparativo.html', 
                          empresa=empresa, 
                          dados=dados_comparativos,
                          graficos=graficos_comparativos)

@visualizacao_bp.route('/visualizacao/empresa/<int:empresa_id>/exportar/pdf')
@login_required
def exportar_pdf(empresa_id):
    # Verificar se o usuário tem permissão para visualizar esta empresa
    if current_user.tipo_perfil == 'cliente' and current_user.empresa_id != empresa_id:
        return render_template('errors/403.html'), 403
    
    # Obter empresa
    empresa = Empresa.query.get_or_404(empresa_id)
    
    # Preparar dados para visualização
    dados_visualizacao = preparar_dados_visualizacao(empresa_id)
    
    # Gerar PDF
    pdf_buffer = gerar_pdf(empresa, dados_visualizacao)
    
    # Criar resposta com o PDF
    response = make_response(pdf_buffer.getvalue())
    response.headers['Content-Type'] = 'application/pdf'
    response.headers['Content-Disposition'] = f'attachment; filename=relatorio_{empresa.nome_fantasia.replace(" ", "_")}.pdf'
    
    return response

@visualizacao_bp.route('/visualizacao/empresa/<int:empresa_id>/exportar/excel')
@login_required
def exportar_excel(empresa_id):
    # Verificar se o usuário tem permissão para visualizar esta empresa
    if current_user.tipo_perfil == 'cliente' and current_user.empresa_id != empresa_id:
        return render_template('errors/403.html'), 403
    
    # Obter empresa
    empresa = Empresa.query.get_or_404(empresa_id)
    
    # Preparar dados para visualização
    dados_visualizacao = preparar_dados_visualizacao(empresa_id)
    
    # Gerar Excel
    excel_buffer = gerar_excel(empresa, dados_visualizacao)
    
    # Criar resposta com o Excel
    response = make_response(excel_buffer.getvalue())
    response.headers['Content-Type'] = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
    response.headers['Content-Disposition'] = f'attachment; filename=dados_{empresa.nome_fantasia.replace(" ", "_")}.xlsx'
    
    return response

# Funções auxiliares
def preparar_dados_visualizacao(empresa_id):
    # Obter todas as respostas da empresa
    respostas = Resposta.query.filter_by(empresa_id=empresa_id).all()
    
    # Organizar dados por setor, assunto e atividade
    dados = {}
    
    for resposta in respostas:
        pergunta = Pergunta.query.get(resposta.pergunta_id)
        
        if not pergunta:
            continue
        
        setor = pergunta.setor.descricao
        assunto = pergunta.assunto.descricao
        atividade = pergunta.atividade.descricao
        
        if setor not in dados:
            dados[setor] = {}
        
        if assunto not in dados[setor]:
            dados[setor][assunto] = {}
        
        if atividade not in dados[setor][assunto]:
            dados[setor][assunto][atividade] = []
        
        dados[setor][assunto][atividade].append({
            'pergunta': pergunta.texto,
            'tipo': pergunta.tipo_resposta,
            'resposta': resposta.valor
        })
    
    return dados

def preparar_dados_comparativos(empresa_id):
    # Obter empresa atual
    empresa = Empresa.query.get(empresa_id)
    
    # Obter empresas do mesmo segmento e faixa de faturamento
    empresas_similares = Empresa.query.filter(
        Empresa.id != empresa_id,
        Empresa.segmento_id == empresa.segmento_id,
        Empresa.faixa_faturamento_id == empresa.faixa_faturamento_id,
        Empresa.ativo == True
    ).all()
    
    # Obter perguntas numéricas respondidas pela empresa atual
    perguntas_numericas = db.session.query(Pergunta).\
        join(Resposta, Resposta.pergunta_id == Pergunta.id).\
        filter(Resposta.empresa_id == empresa_id, Pergunta.tipo_resposta == 'numero').\
        all()
    
    dados_comparativos = {}
    
    for pergunta in perguntas_numericas:
        # Obter resposta da empresa atual
        resposta_atual = Resposta.query.filter_by(
            empresa_id=empresa_id,
            pergunta_id=pergunta.id
        ).first()
        
        if not resposta_atual:
            continue
        
        # Obter respostas das empresas similares
        respostas_similares = Resposta.query.filter(
            Resposta.empresa_id.in_([e.id for e in empresas_similares]),
            Resposta.pergunta_id == pergunta.id
        ).all()
        
        # Calcular média das empresas similares
        valores_similares = [float(r.valor) for r in respostas_similares if r.valor.replace('.', '').isdigit()]
        
        if valores_similares:
            media_similares = sum(valores_similares) / len(valores_similares)
        else:
            media_similares = 0
        
        # Adicionar aos dados comparativos
        dados_comparativos[pergunta.texto] = {
            'valor_empresa': float(resposta_atual.valor) if resposta_atual.valor.replace('.', '').isdigit() else 0,
            'media_similares': media_similares,
            'setor': pergunta.setor.descricao,
            'assunto': pergunta.assunto.descricao
        }
    
    return dados_comparativos

def gerar_graficos(dados):
    graficos = {}
    
    # Gráfico de barras por setor
    plt.figure(figsize=(10, 6))
    setores = list(dados.keys())
    contagem_perguntas = [sum(len(atividade) for assunto in dados[setor].values() for atividade in assunto.values()) for setor in setores]
    
    plt.bar(setores, contagem_perguntas, color='skyblue')
    plt.xlabel('Setores')
    plt.ylabel('Número de Respostas')
    plt.title('Respostas por Setor')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    
    buffer = io.BytesIO()
    plt.savefig(buffer, format='png')
    buffer.seek(0)
    graficos['setores'] = base64.b64encode(buffer.getvalue()).decode('utf-8')
    plt.close()
    
    # Gráfico de pizza por assunto
    plt.figure(figsize=(10, 6))
    assuntos = {}
    for setor in dados:
        for assunto in dados[setor]:
            if assunto not in assuntos:
                assuntos[assunto] = 0
            assuntos[assunto] += sum(len(atividade) for atividade in dados[setor][assunto].values())
    
    plt.pie(assuntos.values(), labels=assuntos.keys(), autopct='%1.1f%%', startangle=90, shadow=True)
    plt.axis('equal')
    plt.title('Distribuição de Respostas por Assunto')
    plt.tight_layout()
    
    buffer = io.BytesIO()
    plt.savefig(buffer, format='png')
    buffer.seek(0)
    graficos['assuntos'] = base64.b64encode(buffer.getvalue()).decode('utf-8')
    plt.close()
    
    return graficos

def gerar_graficos_comparativos(dados_comparativos, empresa):
    graficos = {}
    
    if not dados_comparativos:
        return graficos
    
    # Gráfico de barras comparativo
    plt.figure(figsize=(12, 8))
    perguntas = list(dados_comparativos.keys())
    valores_empresa = [dados_comparativos[p]['valor_empresa'] for p in perguntas]
    valores_similares = [dados_comparativos[p]['media_similares'] for p in perguntas]
    
    x = range(len(perguntas))
    width = 0.35
    
    plt.bar([i - width/2 for i in x], valores_empresa, width, label=f'{empresa.nome_fantasia}')
    plt.bar([i + width/2 for i in x], valores_similares, width, label='Média do Segmento')
    
    plt.xlabel('Perguntas')
    plt.ylabel('Valores')
    plt.title('Comparativo de Respostas')
    plt.xticks(x, [p[:30] + '...' if len(p) > 30 else p for p in perguntas], rotation=45, ha='right')
    plt.legend()
    plt.tight_layout()
    
    buffer = io.BytesIO()
    plt.savefig(buffer, format='png')
    buffer.seek(0)
    graficos['comparativo_barras'] = base64.b64encode(buffer.getvalue()).decode('utf-8')
    plt.close()
    
    # Gráfico de radar comparativo
    categorias = list(set([dados_comparativos[p]['assunto'] for p in perguntas]))
    
    if len(categorias) >= 3:  # Radar chart precisa de pelo menos 3 categorias
        plt.figure(figsize=(10, 10))
        
        # Preparar dados por categoria
        valores_por_categoria_empresa = {}
        valores_por_categoria_similares = {}
        
        for categoria in categorias:
            valores_empresa_cat = [dados_comparativos[p]['valor_empresa'] for p in perguntas if dados_comparativos[p]['assunto'] == categoria]
            valores_similares_cat = [dados_comparativos[p]['media_similares'] for p in perguntas if dados_comparativos[p]['assunto'] == categoria]
            
            if valores_empresa_cat:
                valores_por_categoria_empresa[categoria] = sum(valores_empresa_cat) / len(valores_empresa_cat)
            else:
                valores_por_categoria_empresa[categoria] = 0
                
            if valores_similares_cat:
                valores_por_categoria_similares[categoria] = sum(valores_similares_cat) / len(valores_similares_cat)
            else:
                valores_por_categoria_similares[categoria] = 0
        
        # Criar gráfico radar
        categorias = list(valores_por_categoria_empresa.keys())
        valores_empresa = [valores_por_categoria_empresa[c] for c in categorias]
        valores_similares = [valores_por_categoria_similares[c] for c in categorias]
        
        # Adicionar o primeiro valor no final para fechar o polígono
        categorias_plot = categorias + [categorias[0]]
        valores_empresa_plot = valores_empresa + [valores_empresa[0]]
        valores_similares_plot = valores_similares + [valores_similares[0]]
        
        # Configurar o gráfico radar
        angles = [n / float(len(categorias)) * 2 * 3.14159 for n in range(len(categorias))]
        angles += angles[:1]  # Fechar o círculo
        
        ax = plt.subplot(111, polar=True)
        plt.xticks(angles[:-1], categorias, size=8)
        
        ax.plot(angles, valores_empresa_plot, linewidth=1, linestyle='solid', label=f'{empresa.nome_fantasia}')
        ax.fill(angles, valores_empresa_plot, alpha=0.1)
        
        ax.plot(angles, valores_similares_plot, linewidth=1, linestyle='solid', label='Média do Segmento')
        ax.fill(angles, valores_similares_plot, alpha=0.1)
        
        plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
        plt.title('Comparativo por Assunto')
        
        buffer = io.BytesIO()
        plt.savefig(buffer, format='png')
        buffer.seek(0)
        graficos['comparativo_radar'] = base64.b64encode(buffer.getvalue()).decode('utf-8')
        plt.close()
    
    return graficos

def gerar_pdf(empresa, dados):
    buffer = io.BytesIO()
    
    with PdfPages(buffer) as pdf:
        # Página de capa
        plt.figure(figsize=(8.5, 11))
        plt.axis('off')
        plt.text(0.5, 0.8, 'Relatório de Pesquisa', fontsize=24, ha='center')
        plt.text(0.5, 0.7, empresa.nome_fantasia, fontsize=18, ha='center')
        plt.text(0.5, 0.65, f'CNPJ: {empresa.cnpj}', fontsize=12, ha='center')
        plt.text(0.5, 0.6, f'Segmento: {empresa.segmento.nome}', fontsize=12, ha='center')
        plt.text(0.5, 0.55, f'Localização: {empresa.cidade}/{empresa.estado}', fontsize=12, ha='center')
        plt.text(0.5, 0.2, f'Data: {pd.Timestamp.now().strftime("%d/%m/%Y")}', fontsize=10, ha='center')
        plt.text(0.5, 0.15, 'Plataforma Vinci', fontsize=10, ha='center')
        pdf.savefig()
        plt.close()
        
        # Página de resumo
        plt.figure(figsize=(8.5, 11))
        plt.axis('off')
        plt.text(0.5, 0.95, 'Resumo da Pesquisa', fontsize=16, ha='center')
        
        y_pos = 0.9
        total_perguntas = 0
        
        for setor in dados:
            plt.text(0.1, y_pos, f'Setor: {setor}', fontsize=12)
            y_pos -= 0.03
            
            for assunto in dados[setor]:
                plt.text(0.15, y_pos, f'Assunto: {assunto}', fontsize=10)
                y_pos -= 0.03
                
                for atividade in dados[setor][assunto]:
                    num_perguntas = len(dados[setor][assunto][atividade])
                    total_perguntas += num_perguntas
                    plt.text(0.2, y_pos, f'Atividade: {atividade} - {num_perguntas} pergunta(s)', fontsize=9)
                    y_pos -= 0.03
                    
                    if y_pos < 0.1:  # Nova página se necessário
                        pdf.savefig()
                        plt.close()
                        plt.figure(figsize=(8.5, 11))
                        plt.axis('off')
                        plt.text(0.5, 0.95, 'Resumo da Pesquisa (continuação)', fontsize=16, ha='center')
                        y_pos = 0.9
        
        plt.text(0.1, y_pos, f'Total de perguntas respondidas: {total_perguntas}', fontsize=12)
        pdf.savefig()
        plt.close()
        
        # Gráficos
        graficos = gerar_graficos(dados)
        
        # Página de gráficos
        plt.figure(figsize=(8.5, 11))
        plt.axis('off')
        plt.text(0.5, 0.95, 'Gráficos de Análise', fontsize=16, ha='center')
        
        # Adicionar gráficos como imagens
        for i, (nome, grafico_base64) in enumerate(graficos.items()):
            grafico_img = base64.b64decode(grafico_base64)
            grafico_buffer = io.BytesIO(grafico_img)
            img = plt.imread(grafico_buffer)
            
            ax = plt.subplot(len(graficos), 1, i+1)
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(f'Gráfico: {nome.capitalize()}')
        
        pdf.savefig()
        plt.close()
        
        # Dados detalhados
        plt.figure(figsize=(8.5, 11))
        plt.axis('off')
        plt.text(0.5, 0.95, 'Dados Detalhados', fontsize=16, ha='center')
        
        y_pos = 0.9
        for setor in dados:
            plt.text(0.1, y_pos, f'Setor: {setor}', fontsize=12, weight='bold')
            y_pos -= 0.03
            
            for assunto in dados[setor]:
                plt.text(0.15, y_pos, f'Assunto: {assunto}', fontsize=11, weight='bold')
                y_pos -= 0.03
                
                for atividade in dados[setor][assunto]:
                    plt.text(0.2, y_pos, f'Atividade: {atividade}', fontsize=10, weight='bold')
                    y_pos -= 0.03
                    
                    for item in dados[setor][assunto][atividade]:
                        texto_pergunta = item['pergunta']
                        if len(texto_pergunta) > 60:
                            texto_pergunta = texto_pergunta[:57] + '...'
                        
                        plt.text(0.25, y_pos, f'P: {texto_pergunta}', fontsize=9)
                        y_pos -= 0.02
                        plt.text(0.25, y_pos, f'R: {item["resposta"]}', fontsize=9)
                        y_pos -= 0.03
                        
                        if y_pos < 0.1:  # Nova página se necessário
                            pdf.savefig()
                            plt.close()
                            plt.figure(figsize=(8.5, 11))
                            plt.axis('off')
                            plt.text(0.5, 0.95, 'Dados Detalhados (continuação)', fontsize=16, ha='center')
                            y_pos = 0.9
        
        pdf.savefig()
        plt.close()
    
    buffer.seek(0)
    return buffer

def gerar_excel(empresa, dados):
    buffer = io.BytesIO()
    
    # Criar workbook
    wb = Workbook()
    
    # Planilha de informações
    ws_info = wb.active
    ws_info.title = "Informações"
    
    # Estilo para cabeçalhos
    header_fill = PatternFill(start_color="1F4E78", end_color="1F4E78", fill_type="solid")
    header_font = Font(color="FFFFFF", bold=True)
    
    # Adicionar informações da empresa
    ws_info['A1'] = "Relatório de Pesquisa"
    ws_info['A1'].font = Font(size=16, bold=True)
    ws_info.merge_cells('A1:D1')
    
    ws_info['A3'] = "Empresa:"
    ws_info['B3'] = empresa.nome_fantasia
    ws_info['A4'] = "CNPJ:"
    ws_info['B4'] = empresa.cnpj
    ws_info['A5'] = "Segmento:"
    ws_info['B5'] = empresa.segmento.nome
    ws_info['A6'] = "Localização:"
    ws_info['B6'] = f"{empresa.cidade}/{empresa.estado}"
    ws_info['A7'] = "Data do Relatório:"
    ws_info['B7'] = pd.Timestamp.now().strftime("%d/%m/%Y")
    
    # Planilha de respostas
    ws_respostas = wb.create_sheet("Respostas")
    
    # Cabeçalhos
    headers = ["Setor", "Assunto", "Atividade", "Pergunta", "Tipo de Resposta", "Resposta"]
    for col, header in enumerate(headers, 1):
        cell = ws_respostas.cell(row=1, column=col, value=header)
        cell.fill = header_fill
        cell.font = header_font
    
    # Ajustar largura das colunas
    for col in range(1, len(headers) + 1):
        ws_respostas.column_dimensions[get_column_letter(col)].width = 20
    
    # Adicionar dados
    row = 2
    for setor in dados:
        for assunto in dados[setor]:
            for atividade in dados[setor][assunto]:
                for item in dados[setor][assunto][atividade]:
                    ws_respostas.cell(row=row, column=1, value=setor)
                    ws_respostas.cell(row=row, column=2, value=assunto)
                    ws_respostas.cell(row=row, column=3, value=atividade)
                    ws_respostas.cell(row=row, column=4, value=item['pergunta'])
                    ws_respostas.cell(row=row, column=5, value=item['tipo'])
                    ws_respostas.cell(row=row, column=6, value=item['resposta'])
                    row += 1
    
    # Planilha de resumo
    ws_resumo = wb.create_sheet("Resumo")
    
    # Cabeçalhos
    ws_resumo['A1'] = "Resumo por Setor"
    ws_resumo['A1'].font = Font(size=14, bold=True)
    ws_resumo.merge_cells('A1:B1')
    
    headers = ["Setor", "Quantidade de Respostas"]
    for col, header in enumerate(headers, 1):
        cell = ws_resumo.cell(row=3, column=col, value=header)
        cell.fill = header_fill
        cell.font = header_font
    
    # Adicionar dados de resumo
    row = 4
    for setor in dados:
        count = sum(len(atividade) for assunto in dados[setor].values() for atividade in assunto.values())
        ws_resumo.cell(row=row, column=1, value=setor)
        ws_resumo.cell(row=row, column=2, value=count)
        row += 1
    
    # Ajustar largura das colunas
    for col in range(1, 3):
        ws_resumo.column_dimensions[get_column_letter(col)].width = 25
    
    # Salvar workbook
    wb.save(buffer)
    buffer.seek(0)
    
    return buffer
