Skip to content
Snippets Groups Projects
Verified Commit 2fcd793a authored by Mario Oenning's avatar Mario Oenning
Browse files

fix(test): Fixes for "TestClient" changes

Seems that client is optional according to the ASGI spec.
https://asgi.readthedocs.io/en/latest/specs/www.html

With Starlette 0.35 the TestClient connection  scope is None for "client".
https://github.com/encode/starlette/pull/2377



Signed-off-by: default avatarmoson <moson@archlinux.org>
parent 22e15773
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ from fastapi import Request
from aurweb import db, schema
from aurweb.models.declarative import Base
from aurweb.util import get_client_ip
class Ban(Base):
......@@ -14,6 +15,6 @@ class Ban(Base):
def is_banned(request: Request):
ip = request.client.host
ip = get_client_ip(request)
exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
return db.query(exists).scalar()
......@@ -122,7 +122,7 @@ class User(Base):
try:
with db.begin():
self.LastLogin = now_ts
self.LastLoginIPAddress = request.client.host
self.LastLoginIPAddress = util.get_client_ip(request)
if not self.session:
sid = generate_unique_sid()
self.session = db.create(
......
......@@ -4,6 +4,7 @@ from redis.client import Pipeline
from aurweb import aur_logging, config, db, time
from aurweb.aur_redis import redis_connection
from aurweb.models import ApiRateLimit
from aurweb.util import get_client_ip
logger = aur_logging.get_logger(__name__)
......@@ -13,7 +14,7 @@ def _update_ratelimit_redis(request: Request, pipeline: Pipeline):
now = time.utcnow()
time_to_delete = now - window_length
host = request.client.host
host = get_client_ip(request)
window_key = f"ratelimit-ws:{host}"
requests_key = f"ratelimit:{host}"
......@@ -55,7 +56,7 @@ def _update_ratelimit_db(request: Request):
record.Requests += 1
return record
host = request.client.host
host = get_client_ip(request)
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
record = retry_create(record, now, host)
......@@ -92,7 +93,7 @@ def check_ratelimit(request: Request):
record = update_ratelimit(request, pipeline)
# Get cache value, else None.
host = request.client.host
host = get_client_ip(request)
pipeline.get(f"ratelimit:{host}")
requests = pipeline.execute()[0]
......
......@@ -80,7 +80,9 @@ def open_session(request, conn, user_id):
conn.execute(
Users.update()
.where(Users.c.ID == user_id)
.values(LastLogin=int(time.time()), LastLoginIPAddress=request.client.host)
.values(
LastLogin=int(time.time()), LastLoginIPAddress=util.get_client_ip(request)
)
)
return sid
......@@ -110,7 +112,7 @@ async def authenticate(
Receive an OpenID Connect ID token, validate it, then process it to create
an new AUR session.
"""
if is_ip_banned(conn, request.client.host):
if is_ip_banned(conn, util.get_client_ip(request)):
_ = get_translator_for_request(request)
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
......
......@@ -67,7 +67,7 @@ def invalid_password(
def is_banned(request: Request = None, **kwargs) -> None:
host = request.client.host
host = util.get_client_ip(request)
exists = db.query(models.Ban, models.Ban.IPAddress == host).exists()
if db.query(exists).scalar():
raise ValidationError(
......
......@@ -208,3 +208,11 @@ def hash_query(query: Query):
return sha1(
str(query.statement.compile(compile_kwargs={"literal_binds": True})).encode()
).hexdigest()
def get_client_ip(request: fastapi.Request) -> str:
"""
Returns the client's IP address for a Request.
Falls back to 'no-client' is request.client is None
"""
return request.client.host if request.client else "no-client"
......@@ -391,9 +391,10 @@ def test_post_register_error_invalid_captcha(client: TestClient):
def test_post_register_error_ip_banned(client: TestClient):
# 'testclient' is used as request.client.host via FastAPI TestClient.
# 'no-client' is our fallback value in case request.client is None
# which is the case for TestClient
with db.begin():
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
create(Ban, IPAddress="no-client", BanTS=datetime.utcnow())
with client as request:
response = post_register(request)
......
......@@ -310,10 +310,10 @@ def pipeline():
redis = redis_connection()
pipeline = redis.pipeline()
# The 'testclient' host is used when requesting the app
# via fastapi.testclient.TestClient.
pipeline.delete("ratelimit-ws:testclient")
pipeline.delete("ratelimit:testclient")
# 'no-client' is our fallback value in case request.client is None
# which is the case for TestClient
pipeline.delete("ratelimit-ws:no-client")
pipeline.delete("ratelimit:no-client")
pipeline.execute()
yield pipeline
......@@ -760,8 +760,8 @@ def test_rpc_ratelimit(
assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS)
# Delete the cached records.
pipeline.delete("ratelimit-ws:testclient")
pipeline.delete("ratelimit:testclient")
pipeline.delete("ratelimit-ws:no-client")
pipeline.delete("ratelimit:no-client")
one, two = pipeline.execute()
assert one and two
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment