add api-auth

This commit is contained in:
曾金栋
2023-11-09 17:18:47 +08:00
parent 99cacff46a
commit ee4b20f2fa

View File

@@ -19,7 +19,28 @@ from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig from transformers.generation import GenerationConfig
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
import base64
class BasicAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, app, username: str, password: str):
super().__init__(app)
self.required_credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
async def dispatch(self, request: Request, call_next):
authorization: str = request.headers.get("Authorization")
if authorization:
try:
schema, credentials = authorization.split()
if credentials == self.required_credentials:
return await call_next(request)
except ValueError:
pass
headers = {'WWW-Authenticate': 'Basic'}
return Response(status_code=401, headers=headers)
def _gc(forced: bool = False): def _gc(forced: bool = False):
global args global args
@@ -475,6 +496,9 @@ def _get_args():
default="Qwen/Qwen-7B-Chat", default="Qwen/Qwen-7B-Chat",
help="Checkpoint name or path, default to %(default)r", help="Checkpoint name or path, default to %(default)r",
) )
parser.add_argument(
"--api-auth", help="API authentication credentials"
)
parser.add_argument( parser.add_argument(
"--cpu-only", action="store_true", help="Run demo with CPU only" "--cpu-only", action="store_true", help="Run demo with CPU only"
) )
@@ -504,6 +528,11 @@ if __name__ == "__main__":
resume_download=True, resume_download=True,
) )
if args.api_auth:
app.add_middleware(
BasicAuthMiddleware, username=args.api_auth.split(":")[0], password=args.api_auth.split(":")[1]
)
if args.cpu_only: if args.cpu_only:
device_map = "cpu" device_map = "cpu"
else: else: