from datetime import datetime
from decimal import ROUND_HALF_UP, Decimal
from typing import List

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from .. import models, schemas
from ..database import get_db
from ..deps import get_current_user

# Exam-session endpoints: start/resume, list, inspect, save answers, submit, and review results.
router = APIRouter()


def _compute_grade(score: Decimal, total: int) -> Decimal:
    """Convert raw score (0-total) to Romanian nota (1.00-10.00)."""
    if total == 0:
        return Decimal("1.00")
    pct = float(score) / float(total)
    grade = 1.0 + pct * 9.0
    return Decimal(str(round(min(max(grade, 1.0), 10.0), 2)))


@router.post("/", response_model=schemas.SessionSummary, status_code=201)
def start_session(
    body: schemas.StartSession,
    db: Session = Depends(get_db),
    user: models.User = Depends(get_current_user),
):
    """Open a new exam session for the current user, or resume an existing one.

    Only published exams can be started. If the user already has an
    ``in_progress`` session for the same exam, that session's summary is
    returned and nothing new is created. Otherwise a session is created
    together with one empty answer row (0 points) per question of the exam.

    Raises:
        HTTPException(404): the exam does not exist or is not published.
    """
    exam_id = str(body.exam_id)

    exam = (
        db.query(models.Exam)
        .filter(
            models.Exam.id == exam_id,
            models.Exam.is_published == True,  # noqa: E712 - builds a SQL expression
        )
        .first()
    )
    if not exam:
        raise HTTPException(status_code=404, detail="Examen negăsit")

    in_progress = (
        db.query(models.ExamSession)
        .filter(
            models.ExamSession.user_id == user.id,
            models.ExamSession.exam_id == exam_id,
            models.ExamSession.status == "in_progress",
        )
        .first()
    )
    if in_progress:
        # Resume instead of creating a second in-progress session for the same exam.
        return _session_summary(in_progress)

    new_session = models.ExamSession(
        user_id=user.id,
        exam_id=exam_id,
        total_points=exam.total_points,
    )
    db.add(new_session)
    # Flush so new_session.id is populated before the answer rows reference it.
    db.flush()

    # One placeholder answer per question, so later saves only have to update rows.
    db.add_all(
        [
            models.Answer(
                session_id=new_session.id,
                question_id=q.id,
                points_earned=Decimal("0"),
            )
            for q in exam.questions
        ]
    )

    db.commit()
    db.refresh(new_session)
    return _session_summary(new_session)


@router.get("/", response_model=List[schemas.SessionSummary])
def list_sessions(
    db: Session = Depends(get_db),
    user: models.User = Depends(get_current_user),
):
    """Return summaries of all sessions of the current user, most recently started first."""
    own_sessions = db.query(models.ExamSession).filter(models.ExamSession.user_id == user.id)
    newest_first = own_sessions.order_by(models.ExamSession.started_at.desc()).all()
    return list(map(_session_summary, newest_first))


@router.get("/{session_id}", response_model=schemas.SessionSummary)
def get_session(
    session_id: str,
    db: Session = Depends(get_db),
    user: models.User = Depends(get_current_user),
):
    """Return the summary of one session; 404 if it is missing or owned by another user."""
    return _session_summary(_get_session_for_user(session_id, user.id, db))


@router.put("/{session_id}/answers/{question_id}")
def save_answer(
    session_id: str,
    question_id: str,
    body: schemas.SaveAnswer,
    db: Session = Depends(get_db),
    user: models.User = Depends(get_current_user),
):
    """Store the user's answer to one question of an in-progress session.

    A non-empty ``selected_option_id`` must be one of the options of
    ``question_id``. Anything else is rejected, so an answer can never point
    at an option belonging to a different question.

    Raises:
        HTTPException(404): session missing/not owned, or the question is not part of it.
        HTTPException(400): session already finalized, or the option does not
            belong to this question.
    """
    session = _get_session_for_user(session_id, user.id, db)
    if session.status != "in_progress":
        raise HTTPException(status_code=400, detail="Sesiunea a fost deja finalizată")

    answer = db.query(models.Answer).filter(
        models.Answer.session_id == session_id,
        models.Answer.question_id == question_id,
    ).first()
    if not answer:
        raise HTTPException(status_code=404, detail="Întrebarea nu aparține acestei sesiuni")

    selected_option_id = str(body.selected_option_id) if body.selected_option_id else None
    # Compare as strings: ids are stored/queried as strings elsewhere in this module.
    if selected_option_id is not None and not any(
        str(option.id) == selected_option_id for option in answer.question.options
    ):
        raise HTTPException(
            status_code=400,
            detail="Opțiunea selectată nu aparține acestei întrebări",
        )

    answer.selected_option_id = selected_option_id
    answer.answer_text = body.answer_text
    answer.answered_at = datetime.utcnow()
    db.commit()
    return {"ok": True}


def _grade_submitted_answer(answer: models.Answer) -> None:
    """Auto-grade one answer in place at submission time.

    Multiple-choice and true/false answers with a selected option get
    ``is_correct`` and ``points_earned`` set from that option. The option is
    looked up only among the options of the answer's *own* question, so an
    option id from another question never earns points. Anything else
    (open-ended or unanswered questions) is left ungraded: ``is_correct``
    becomes ``None`` and ``points_earned`` is not touched.
    """
    question = answer.question
    if question.question_type not in ("multiple_choice", "true_false") or not answer.selected_option_id:
        answer.is_correct = None
        return

    selected_id = str(answer.selected_option_id)
    selected_opt = next(
        (option for option in question.options if str(option.id) == selected_id),
        None,
    )
    if selected_opt is not None and selected_opt.is_correct:
        answer.is_correct = True
        answer.points_earned = question.points
    else:
        answer.is_correct = False
        answer.points_earned = Decimal("0")


@router.post("/{session_id}/submit", response_model=schemas.SessionResult)
def submit_session(
    session_id: str,
    db: Session = Depends(get_db),
    user: models.User = Depends(get_current_user),
):
    """Finalize an in-progress session: auto-grade it, store score/grade, and return the result.

    Each answer is graded by ``_grade_submitted_answer``. The summed score is
    converted to a grade via ``_compute_grade`` (``total_points`` falls back
    to 100 when unset), and the session is passed when the grade is at least
    5.00.

    Raises:
        HTTPException(404): session missing or owned by another user.
        HTTPException(400): session already finalized.
    """
    session = _get_session_for_user(session_id, user.id, db)
    if session.status != "in_progress":
        raise HTTPException(status_code=400, detail="Sesiunea a fost deja finalizată")

    now = datetime.utcnow()
    time_spent = int((now - session.started_at).total_seconds())
    total_score = Decimal("0")
    answer_results = []

    for answer in session.answers:
        question = answer.question
        _grade_submitted_answer(answer)
        # Ungraded answers may carry no points at all; treat that as zero.
        total_score += answer.points_earned or Decimal("0")

        options_with_correct = [
            schemas.OptionWithCorrect(
                id=o.id,
                option_label=o.option_label,
                option_text=o.option_text,
                is_correct=o.is_correct,
            )
            for o in question.options
        ]
        answer_results.append(schemas.AnswerResult(
            question_id=question.id,
            question_number=question.question_number,
            question_text=question.question_text,
            question_type=question.question_type,
            points=question.points,
            selected_option_id=answer.selected_option_id,
            answer_text=answer.answer_text,
            is_correct=answer.is_correct,
            points_earned=answer.points_earned,
            options=options_with_correct,
        ))

    total_points = session.total_points or 100
    grade = _compute_grade(total_score, total_points)

    session.submitted_at = now
    session.time_spent_seconds = time_spent
    session.status = "submitted"
    session.total_score = total_score
    session.grade = grade
    session.passed = grade >= Decimal("5.00")
    db.commit()

    return schemas.SessionResult(
        id=session.id,
        exam_id=session.exam_id,
        exam_title=session.exam.title,
        subject_name=session.exam.subject.name,
        exam_type_code=session.exam.exam_type.code,
        started_at=session.started_at,
        submitted_at=session.submitted_at,
        time_spent_seconds=session.time_spent_seconds,
        status=session.status,
        total_score=total_score,
        total_points=total_points,
        grade=grade,
        passed=session.passed,
        answers=answer_results,
    )


@router.get("/{session_id}/result", response_model=schemas.SessionResult)
def get_result(
    session_id: str,
    db: Session = Depends(get_db),
    user: models.User = Depends(get_current_user),
):
    """Return the stored result of a finalized session, including correct options.

    Nothing is re-graded here. The values persisted at submission are
    reported, with fallbacks when a stored value is missing/falsy: score 0,
    total points 100, grade 1, and not passed.

    Raises:
        HTTPException(404): session missing or owned by another user.
        HTTPException(400): session is still in progress.
    """
    session = _get_session_for_user(session_id, user.id, db)
    if session.status == "in_progress":
        raise HTTPException(status_code=400, detail="Sesiunea nu a fost finalizată")

    def to_answer_result(ans):
        # Expose every option with its correctness flag, since the session is final.
        q = ans.question
        return schemas.AnswerResult(
            question_id=q.id,
            question_number=q.question_number,
            question_text=q.question_text,
            question_type=q.question_type,
            points=q.points,
            selected_option_id=ans.selected_option_id,
            answer_text=ans.answer_text,
            is_correct=ans.is_correct,
            points_earned=ans.points_earned,
            options=[
                schemas.OptionWithCorrect(
                    id=opt.id,
                    option_label=opt.option_label,
                    option_text=opt.option_text,
                    is_correct=opt.is_correct,
                )
                for opt in q.options
            ],
        )

    rows = [to_answer_result(a) for a in session.answers]
    exam = session.exam
    return schemas.SessionResult(
        id=session.id,
        exam_id=session.exam_id,
        exam_title=exam.title,
        subject_name=exam.subject.name,
        exam_type_code=exam.exam_type.code,
        started_at=session.started_at,
        submitted_at=session.submitted_at,
        time_spent_seconds=session.time_spent_seconds,
        status=session.status,
        total_score=session.total_score or Decimal("0"),
        total_points=session.total_points or 100,
        grade=session.grade or Decimal("1"),
        passed=session.passed or False,
        answers=rows,
    )


# ── helpers ────────────────────────────────────────────────────────────────────

def _get_session_for_user(session_id: str, user_id, db: Session) -> models.ExamSession:
    """Load a session by id, restricted to the given owner.

    A session that does not exist and one that belongs to another user both
    produce the same 404, so the two cases are indistinguishable to the caller.
    """
    owned = (
        db.query(models.ExamSession)
        .filter(models.ExamSession.id == session_id)
        .filter(models.ExamSession.user_id == user_id)
        .first()
    )
    if not owned:
        raise HTTPException(status_code=404, detail="Sesiune negăsită")
    return owned


def _session_summary(s: models.ExamSession) -> schemas.SessionSummary:
    """Build the summary view of a session, denormalizing exam title, subject name and exam type code."""
    # Attributes copied verbatim from the session onto the summary schema.
    plain_fields = (
        "id",
        "exam_id",
        "started_at",
        "submitted_at",
        "status",
        "total_score",
        "total_points",
        "grade",
        "passed",
    )
    exam = s.exam
    return schemas.SessionSummary(
        **{name: getattr(s, name) for name in plain_fields},
        exam_title=exam.title,
        subject_name=exam.subject.name,
        exam_type_code=exam.exam_type.code,
    )
