From 232a86e11f364ec5f1e7807b6519ee62e1361908 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 24 Apr 2026 13:11:09 +0800 Subject: [PATCH] =?UTF-8?q?reafactor:=20=E4=BC=98=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- luxx/routes/providers.py | 220 ++++++++++++++++++--------------------- 1 file changed, 101 insertions(+), 119 deletions(-) diff --git a/luxx/routes/providers.py b/luxx/routes/providers.py index fddbb5f..f3f16bf 100644 --- a/luxx/routes/providers.py +++ b/luxx/routes/providers.py @@ -2,8 +2,9 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel +from sqlalchemy.orm import Session -from luxx.database import get_db, SessionLocal +from luxx.database import get_db from luxx.models import User, LLMProvider from luxx.routes.auth import get_current_user from luxx.utils.helpers import success_response @@ -37,32 +38,28 @@ class ProviderUpdate(BaseModel): @router.get("/", response_model=dict) def list_providers( - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) ): """Get user's LLM providers""" - db = SessionLocal() - try: - providers = db.query(LLMProvider).filter( - LLMProvider.user_id == current_user.id - ).order_by(LLMProvider.is_default.desc(), LLMProvider.created_at.desc()).all() - - return success_response(data={ - "providers": [p.to_dict() for p in providers], - "total": len(providers) - }) - finally: - db.close() + providers = db.query(LLMProvider).filter( + LLMProvider.user_id == current_user.id + ).order_by(LLMProvider.is_default.desc(), LLMProvider.created_at.desc()).all() + + return success_response(data={ + "providers": [p.to_dict() for p in providers], + "total": len(providers) + }) @router.post("/", response_model=dict) def create_provider( provider: ProviderCreate, - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) ): """Create a new LLM provider""" - db = SessionLocal() try: - db_provider = LLMProvider( user_id=current_user.id, name=provider.name, @@ -80,83 +77,96 @@ def create_provider( except Exception as e: db.rollback() raise HTTPException(status_code=400, detail=str(e)) - finally: - db.close() @router.get("/{provider_id}", response_model=dict) def get_provider( provider_id: int, - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) ): """Get provider details""" - db = SessionLocal() - try: - provider = db.query(LLMProvider).filter( - LLMProvider.id == provider_id, - LLMProvider.user_id == current_user.id - ).first() - - if not provider: - raise HTTPException(status_code=404, detail="Provider not found") - - return success_response(data=provider.to_dict(include_key=True)) - finally: - db.close() + provider = db.query(LLMProvider).filter( + LLMProvider.id == provider_id, + LLMProvider.user_id == current_user.id + ).first() + + if not provider: + raise HTTPException(status_code=404, detail="Provider not found") + + return success_response(data=provider.to_dict(include_key=True)) @router.put("/{provider_id}", response_model=dict) def update_provider( provider_id: int, update: ProviderUpdate, - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) ): """Update provider""" - db = SessionLocal() + provider = db.query(LLMProvider).filter( + LLMProvider.id == provider_id, + LLMProvider.user_id == current_user.id + ).first() + + if not provider: + raise HTTPException(status_code=404, detail="Provider not found") + + # If setting as default, unset others + if update.is_default: + db.query(LLMProvider).filter( + LLMProvider.user_id == current_user.id, + LLMProvider.id != provider_id + ).update({"is_default": False}) + + # Update fields + update_data = update.dict(exclude_unset=True) + # Keep existing API key if the new one is empty + if update_data.get('api_key') == '': + update_data.pop('api_key') + for key, value in update_data.items(): + setattr(provider, key, value) + try: - provider = db.query(LLMProvider).filter( - LLMProvider.id == provider_id, - LLMProvider.user_id == current_user.id - ).first() - - if not provider: - raise HTTPException(status_code=404, detail="Provider not found") - - # If setting as default, unset others - if update.is_default: - db.query(LLMProvider).filter( - LLMProvider.user_id == current_user.id, - LLMProvider.id != provider_id - ).update({"is_default": False}) - - # Update fields - update_data = update.dict(exclude_unset=True) - # Keep existing API key if the new one is empty - if update_data.get('api_key') == '': - update_data.pop('api_key') - for key, value in update_data.items(): - setattr(provider, key, value) - db.commit() db.refresh(provider) - return success_response(data=provider.to_dict(include_key=True)) except HTTPException: raise except Exception as e: db.rollback() raise HTTPException(status_code=400, detail=str(e)) - finally: - db.close() @router.delete("/{provider_id}", response_model=dict) def delete_provider( provider_id: int, - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) ): """Delete provider""" - db = SessionLocal() + provider = db.query(LLMProvider).filter( + LLMProvider.id == provider_id, + LLMProvider.user_id == current_user.id + ).first() + + if not provider: + raise HTTPException(status_code=404, detail="Provider not found") + + db.delete(provider) + db.commit() + + return success_response(message="Provider deleted") + + +@router.post("/{provider_id}/test") +def test_provider( + provider_id: int, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Test provider connection""" try: provider = db.query(LLMProvider).filter( LLMProvider.id == provider_id, @@ -164,63 +174,35 @@ def delete_provider( ).first() if not provider: - raise HTTPException(status_code=404, detail="Provider not found") + return {"success": False, "message": "Provider not found"} - db.delete(provider) - db.commit() - - return success_response(message="Provider deleted") - finally: - db.close() - - -@router.post("/{provider_id}/test") -def test_provider( - provider_id: int, - current_user: User = Depends(get_current_user) -): - """Test provider connection""" - - try: - db = SessionLocal() - try: - provider = db.query(LLMProvider).filter( - LLMProvider.id == provider_id, - LLMProvider.user_id == current_user.id - ).first() - - if not provider: - return {"success": False, "message": "Provider not found"} - - # Test the connection - async def test(): - async with httpx.AsyncClient(timeout=10.0) as client: - response = await client.post( - provider.base_url, - headers={ - "Authorization": f"Bearer {provider.api_key}", - "Content-Type": "application/json" - }, - json={ - "model": provider.default_model, - "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hi"}], - "stream": False - } - ) - response.raise_for_status() - return { - "status_code": response.status_code, - "success": True, - "response_body": response.text[:500] if response.text else None + # Test the connection + async def test(): + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.post( + provider.base_url, + headers={ + "Authorization": f"Bearer {provider.api_key}", + "Content-Type": "application/json" + }, + json={ + "model": provider.default_model, + "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hi"}], + "stream": False } - - result = asyncio.run(test()) - return { - "success": result.get("success", False), - "message": result.get("message") or (f"HTTP {result.get('status_code', '?')}: {result.get('response_body') or 'Unknown error'}"), - "data": result - } - finally: - db.close() + ) + response.raise_for_status() + return { + "status_code": response.status_code, + "success": True, + "response_body": response.text[:500] if response.text else None + } + + result = asyncio.run(test()) + return { + "success": result.get("success", False), + "message": result.get("message") or (f"HTTP {result.get('status_code', '?')}: {result.get('response_body') or 'Unknown error'}"), + "data": result + } except Exception as e: return {"success": False, "message": f"Exception {str(e)}", "error_type": type(e).__name__, "traceback": traceback.format_exc()[:500]}