Merge pull request #606 from joindn/api-auth

add api-auth for openai_api.py
This commit is contained in:
Jianxin Ma
2023-12-26 15:25:44 +08:00
committed by GitHub

View File

@@ -19,7 +19,28 @@ from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig from transformers.generation import GenerationConfig
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
import base64
class BasicAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, app, username: str, password: str):
super().__init__(app)
self.required_credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
async def dispatch(self, request: Request, call_next):
authorization: str = request.headers.get("Authorization")
if authorization:
try:
schema, credentials = authorization.split()
if credentials == self.required_credentials:
return await call_next(request)
except ValueError:
pass
headers = {'WWW-Authenticate': 'Basic'}
return Response(status_code=401, headers=headers)
def _gc(forced: bool = False): def _gc(forced: bool = False):
global args global args
@@ -482,6 +503,9 @@ def _get_args():
default="Qwen/Qwen-7B-Chat", default="Qwen/Qwen-7B-Chat",
help="Checkpoint name or path, default to %(default)r", help="Checkpoint name or path, default to %(default)r",
) )
parser.add_argument(
"--api-auth", help="API authentication credentials"
)
parser.add_argument( parser.add_argument(
"--cpu-only", action="store_true", help="Run demo with CPU only" "--cpu-only", action="store_true", help="Run demo with CPU only"
) )
@@ -511,6 +535,11 @@ if __name__ == "__main__":
resume_download=True, resume_download=True,
) )
if args.api_auth:
app.add_middleware(
BasicAuthMiddleware, username=args.api_auth.split(":")[0], password=args.api_auth.split(":")[1]
)
if args.cpu_only: if args.cpu_only:
device_map = "cpu" device_map = "cpu"
else: else: