mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
add api-auth
This commit is contained in:
@@ -19,8 +19,29 @@ from pydantic import BaseModel, Field
|
|||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
from transformers.generation import GenerationConfig
|
from transformers.generation import GenerationConfig
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
import base64
|
||||||
|
|
||||||
|
class BasicAuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
def __init__(self, app, username: str, password: str):
|
||||||
|
super().__init__(app)
|
||||||
|
self.required_credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
authorization: str = request.headers.get("Authorization")
|
||||||
|
if authorization:
|
||||||
|
try:
|
||||||
|
schema, credentials = authorization.split()
|
||||||
|
if credentials == self.required_credentials:
|
||||||
|
return await call_next(request)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
headers = {'WWW-Authenticate': 'Basic'}
|
||||||
|
return Response(status_code=401, headers=headers)
|
||||||
|
|
||||||
def _gc(forced: bool = False):
|
def _gc(forced: bool = False):
|
||||||
global args
|
global args
|
||||||
if args.disable_gc and not forced:
|
if args.disable_gc and not forced:
|
||||||
@@ -475,6 +496,9 @@ def _get_args():
|
|||||||
default="Qwen/Qwen-7B-Chat",
|
default="Qwen/Qwen-7B-Chat",
|
||||||
help="Checkpoint name or path, default to %(default)r",
|
help="Checkpoint name or path, default to %(default)r",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--api-auth", help="API authentication credentials"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cpu-only", action="store_true", help="Run demo with CPU only"
|
"--cpu-only", action="store_true", help="Run demo with CPU only"
|
||||||
)
|
)
|
||||||
@@ -504,6 +528,11 @@ if __name__ == "__main__":
|
|||||||
resume_download=True,
|
resume_download=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.api_auth:
|
||||||
|
app.add_middleware(
|
||||||
|
BasicAuthMiddleware, username=args.api_auth.split(":")[0], password=args.api_auth.split(":")[1]
|
||||||
|
)
|
||||||
|
|
||||||
if args.cpu_only:
|
if args.cpu_only:
|
||||||
device_map = "cpu"
|
device_map = "cpu"
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user