# -*- coding: utf-8 -*- """ 小程序认证路由 —— 微信登录、申请提交、状态查询、店铺切换、令牌刷新。 端点清单: - POST /api/xcx/login — 微信登录(查找/创建用户 + 签发 JWT) - POST /api/xcx/apply — 提交入驻申请 - GET /api/xcx/me — 查询自身状态 + 申请列表 - GET /api/xcx/me/sites — 查询关联店铺 - POST /api/xcx/switch-site — 切换当前店铺 - POST /api/xcx/refresh — 刷新令牌 """ from __future__ import annotations import logging from fastapi import APIRouter, Depends, HTTPException, status from jose import JWTError from psycopg2 import errors as pg_errors from app.auth.dependencies import ( CurrentUser, get_current_user, get_current_user_or_limited, ) from app.auth.jwt import ( create_limited_token_pair, create_token_pair, decode_refresh_token, ) from app.database import get_connection from app.services.application import ( create_application, get_user_applications, ) from app.schemas.xcx_auth import ( ApplicationRequest, ApplicationResponse, RefreshTokenRequest, SiteInfo, SwitchSiteRequest, UserStatusResponse, WxLoginRequest, WxLoginResponse, ) from app.services.wechat import WeChatAuthError, code2session logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/xcx", tags=["小程序认证"]) # ── 辅助:查询用户在指定 site_id 下的角色 code 列表 ────────── def _get_user_roles_at_site(conn, user_id: int, site_id: int) -> list[str]: """查询用户在指定 site_id 下的角色 code 列表。""" with conn.cursor() as cur: cur.execute( """ SELECT r.code FROM auth.user_site_roles usr JOIN auth.roles r ON usr.role_id = r.id WHERE usr.user_id = %s AND usr.site_id = %s """, (user_id, site_id), ) return [row[0] for row in cur.fetchall()] def _get_user_default_site(conn, user_id: int) -> int | None: """获取用户第一个关联的 site_id(按创建时间排序)。""" with conn.cursor() as cur: cur.execute( """ SELECT DISTINCT site_id FROM auth.user_site_roles WHERE user_id = %s ORDER BY site_id LIMIT 1 """, (user_id,), ) row = cur.fetchone() return row[0] if row else None # ── POST /api/xcx/login ────────────────────────────────── @router.post("/login", response_model=WxLoginResponse) async def wx_login(body: WxLoginRequest): """ 微信登录。 流程:code → code2session(openid) → 查找/创建 auth.users → 签发 JWT。 - disabled 用户返回 403 - 新用户自动创建(status=pending) - approved 用户签发包含 site_id + roles 的完整令牌 - pending/rejected 用户签发受限令牌 """ # 1. 调用微信 code2Session try: wx_result = await code2session(body.code) except WeChatAuthError as exc: raise HTTPException(status_code=exc.http_status, detail=exc.detail) except RuntimeError as exc: # 微信配置缺失 logger.error("微信配置错误: %s", exc) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="服务器配置错误", ) openid = wx_result["openid"] unionid = wx_result.get("unionid") # 2. 查找/创建用户 conn = get_connection() try: with conn.cursor() as cur: cur.execute( "SELECT id, status FROM auth.users WHERE wx_openid = %s", (openid,), ) row = cur.fetchone() if row is None: # 新用户:创建 pending 记录 try: cur.execute( """ INSERT INTO auth.users (wx_openid, wx_union_id, status) VALUES (%s, %s, 'pending') RETURNING id, status """, (openid, unionid), ) row = cur.fetchone() conn.commit() except pg_errors.UniqueViolation: # 并发创建:回滚后查询已有记录 conn.rollback() cur.execute( "SELECT id, status FROM auth.users WHERE wx_openid = %s", (openid,), ) row = cur.fetchone() user_id, user_status = row # 3. disabled 用户拒绝登录 if user_status == "disabled": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="账号已被禁用", ) # 4. 签发令牌 if user_status == "approved": # 查找默认 site_id 和角色 default_site_id = _get_user_default_site(conn, user_id) if default_site_id is not None: roles = _get_user_roles_at_site(conn, user_id, default_site_id) tokens = create_token_pair(user_id, default_site_id, roles=roles) else: # approved 但无 site 绑定(异常边界),签发受限令牌 tokens = create_limited_token_pair(user_id) else: # pending / rejected → 受限令牌 tokens = create_limited_token_pair(user_id) finally: conn.close() return WxLoginResponse( access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type=tokens["token_type"], user_status=user_status, user_id=user_id, ) # ── POST /api/xcx/apply ────────────────────────────────── @router.post("/apply", response_model=ApplicationResponse) async def submit_application( body: ApplicationRequest, user: CurrentUser = Depends(get_current_user_or_limited), ): """ 提交入驻申请。 委托 application service 处理: 检查重复 pending → site_code 映射 → 创建记录 → 更新 nickname。 """ result = await create_application( user_id=user.user_id, site_code=body.site_code, applied_role_text=body.applied_role_text, phone=body.phone, employee_number=body.employee_number, nickname=body.nickname, ) return ApplicationResponse(**result) # ── GET /api/xcx/me ─────────────────────────────────────── @router.get("/me", response_model=UserStatusResponse) async def get_my_status( user: CurrentUser = Depends(get_current_user_or_limited), ): """ 查询自身状态 + 所有申请记录。 pending / approved / rejected 用户均可访问。 """ conn = get_connection() try: with conn.cursor() as cur: # 查询用户基本信息 cur.execute( "SELECT id, status, nickname FROM auth.users WHERE id = %s", (user.user_id,), ) user_row = cur.fetchone() if user_row is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在", ) user_id, user_status, nickname = user_row finally: conn.close() # 委托 service 查询申请列表 app_list = await get_user_applications(user_id) applications = [ApplicationResponse(**a) for a in app_list] return UserStatusResponse( user_id=user_id, status=user_status, nickname=nickname, applications=applications, ) # ── GET /api/xcx/me/sites ──────────────────────────────── @router.get("/me/sites", response_model=list[SiteInfo]) async def get_my_sites( user: CurrentUser = Depends(get_current_user), ): """ 查询当前用户关联的所有店铺及对应角色。 仅 approved 用户可访问(通过 get_current_user 依赖保证)。 """ conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ SELECT usr.site_id, COALESCE(scm.site_name, '未知店铺') AS site_name, r.code AS role_code, r.name AS role_name FROM auth.user_site_roles usr JOIN auth.roles r ON usr.role_id = r.id LEFT JOIN auth.site_code_mapping scm ON usr.site_id = scm.site_id WHERE usr.user_id = %s ORDER BY usr.site_id, r.code """, (user.user_id,), ) rows = cur.fetchall() finally: conn.close() # 按 site_id 分组 sites_map: dict[int, SiteInfo] = {} for site_id, site_name, role_code, role_name in rows: if site_id not in sites_map: sites_map[site_id] = SiteInfo( site_id=site_id, site_name=site_name, roles=[] ) sites_map[site_id].roles.append({"code": role_code, "name": role_name}) return list(sites_map.values()) # ── POST /api/xcx/switch-site ──────────────────────────── @router.post("/switch-site", response_model=WxLoginResponse) async def switch_site( body: SwitchSiteRequest, user: CurrentUser = Depends(get_current_user), ): """ 切换当前店铺。 验证用户在目标 site_id 下有角色绑定,然后签发包含新 site_id 的 JWT。 """ conn = get_connection() try: with conn.cursor() as cur: # 验证用户在目标 site_id 下有角色 cur.execute( """ SELECT 1 FROM auth.user_site_roles WHERE user_id = %s AND site_id = %s LIMIT 1 """, (user.user_id, body.site_id), ) if cur.fetchone() is None: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="您在该店铺下没有角色绑定", ) roles = _get_user_roles_at_site(conn, user.user_id, body.site_id) # 查询用户状态 cur.execute( "SELECT status FROM auth.users WHERE id = %s", (user.user_id,), ) user_row = cur.fetchone() user_status = user_row[0] if user_row else "pending" finally: conn.close() tokens = create_token_pair(user.user_id, body.site_id, roles=roles) return WxLoginResponse( access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type=tokens["token_type"], user_status=user_status, user_id=user.user_id, ) # ── POST /api/xcx/refresh ──────────────────────────────── @router.post("/refresh", response_model=WxLoginResponse) async def refresh_token(body: RefreshTokenRequest): """ 刷新令牌。 解码 refresh_token → 根据用户当前状态签发新的令牌对。 - 受限 refresh_token(limited=True)→ 签发新的受限令牌对 - 完整 refresh_token → 签发新的完整令牌对(保持原 site_id) """ try: payload = decode_refresh_token(body.refresh_token) except JWTError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的刷新令牌", ) user_id = int(payload["sub"]) is_limited = payload.get("limited", False) conn = get_connection() try: with conn.cursor() as cur: # 查询用户当前状态 cur.execute( "SELECT id, status FROM auth.users WHERE id = %s", (user_id,), ) user_row = cur.fetchone() if user_row is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在", ) _, user_status = user_row if user_status == "disabled": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="账号已被禁用", ) if is_limited or user_status != "approved": # 受限令牌刷新 → 仍签发受限令牌 tokens = create_limited_token_pair(user_id) else: # 完整令牌刷新 → 使用原 site_id 签发 site_id = payload.get("site_id") if site_id is None: # 回退到默认 site site_id = _get_user_default_site(conn, user_id) if site_id is not None: roles = _get_user_roles_at_site(conn, user_id, site_id) tokens = create_token_pair(user_id, site_id, roles=roles) else: tokens = create_limited_token_pair(user_id) finally: conn.close() return WxLoginResponse( access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type=tokens["token_type"], user_status=user_status, user_id=user_id, )