Coverage for tests/publisher/endpoint_testing.py: 92%
160 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 22:43 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 22:43 +0000
1from unittest.mock import patch
3import requests
5import pymacaroons
6import responses
7from flask_testing import TestCase
8from webapp.app import create_app
9from webapp.authentication import get_authorization_header
# Make the default ``responses`` mock (the one ``@responses.activate`` uses)
# assert that every registered mock response was actually requested, so a
# test fails when it mocks an API call the endpoint never makes ("stray"
# responses).
responses.mock.assert_all_requests_are_fired = True
class BaseTestCases:
    """
    Reusable test-case bases for endpoints that require authentication.

    Endpoint test modules should subclass the nested classes below. Keeping
    them nested inside this plain (non-TestCase) class stops unittest from
    collecting and running the bases themselves.
    """

    class BaseAppTesting(TestCase):
        """Shared Flask app setup and helpers for the endpoint test bases."""

        def setUp(self, snap_name, api_url, endpoint_url):
            """Record the snap name, mocked store API URL and endpoint URL.

            Subclasses call this from their own ``setUp`` with concrete
            values.
            """
            self.snap_name = snap_name
            self.api_url = api_url
            self.endpoint_url = endpoint_url

        def tearDown(self):
            # Drop mocked responses and recorded calls left by the test.
            responses.reset()

        def create_app(self):
            """Build the app under test (hook called by flask_testing)."""
            app = create_app(testing=True)
            app.secret_key = "secret_key"
            # An empty list means Flask-WTF protects no HTTP method, so the
            # tests can POST without supplying a CSRF token.
            app.config["WTF_CSRF_METHODS"] = []

            return app

        def _get_location(self):
            """Return the URL the endpoint is expected to redirect back to."""
            return str(self.endpoint_url)

        def _log_in(self, client):
            """Emulate a logged-in store user in the test client's session.

            Fill the current session with ``publisher``, ``macaroon_root``
            and ``macaroon_discharge``.

            :param client: Flask test client whose session is populated.
            :return: The ``Authorization`` header value expected on store
                API requests, for later verification.
            """
            # Basic root/discharge macaroons pair. The discharge is minted
            # with the same caveat key used by the root's third-party caveat,
            # so the two form a consistent pair.
            caveat_key = "a_caveat_key"
            root = pymacaroons.Macaroon("test", "testing", "a_key")
            root.add_third_party_caveat("3rd", caveat_key, "a_ident")
            discharge = pymacaroons.Macaroon("3rd", "a_ident", caveat_key)

            root_serialized = root.serialize()
            discharge_serialized = discharge.serialize()

            with client.session_transaction() as s:
                s["publisher"] = {
                    "image": None,
                    "nickname": "Toto",
                    "fullname": "El Toto",
                    "email": "testing@testing.com",
                    "stores": [],
                }
                s["macaroon_root"] = root_serialized
                s["macaroon_discharge"] = discharge_serialized

            return get_authorization_header(
                root_serialized, discharge_serialized
            )

        def check_call_by_api_url(self, calls):
            """Assert the store API URL was called with the expected auth.

            Every recorded call to ``self.api_url`` must carry
            ``self.authorization`` as its ``Authorization`` header, and at
            least one such call must exist.

            :param calls: Recorded ``responses`` calls (``responses.calls``).
            """
            matching = [
                call for call in calls if call.request.url == self.api_url
            ]
            for call in matching:
                self.assertEqual(
                    self.authorization,
                    call.request.headers.get("Authorization"),
                )

            self.assertTrue(
                matching, "No request was made to {}".format(self.api_url)
            )

    class EndpointLoggedOut(BaseAppTesting):
        """Checks that an endpoint redirects anonymous users to login."""

        def setUp(self, snap_name, endpoint_url, method_endpoint="GET"):
            self.method_endpoint = method_endpoint
            super().setUp(snap_name, None, endpoint_url)

        def test_access_not_logged_in(self):
            if self.method_endpoint == "GET":
                response = self.client.get(self.endpoint_url)
            else:
                # Any non-GET endpoint is exercised as an empty form POST.
                response = self.client.post(self.endpoint_url, data={})

            self.assertEqual(302, response.status_code)
            self.assertEqual(
                "/login?next={}".format(self.endpoint_url),
                response.location,
            )

    class EndpointLoggedIn(BaseAppTesting):
        """Store API failure scenarios for an endpoint of a logged-in user.

        Each test mocks the store API call at ``api_url`` with a failure
        response, calls the endpoint, and checks how it reacts.
        """

        def setUp(
            self,
            snap_name,
            endpoint_url,
            api_url,
            method_endpoint="GET",
            method_api="GET",
            data=None,
            json=None,
        ):
            """Configure the endpoint/API pair and log the client in.

            :param method_endpoint: ``"GET"``, or anything else to POST.
            :param method_api: HTTP method of the mocked store API call.
            :param data: Form data to POST (used when truthy).
            :param json: JSON body to POST when ``data`` is falsy.
            """
            super().setUp(
                snap_name=snap_name, api_url=api_url, endpoint_url=endpoint_url
            )

            # Stub the blueprint-level "has releases" gate so each test does
            # not need to mock the extra dashboard call. Tests that exercise
            # the gate directly can override this in their own setUp.
            self._release_history_patcher = patch(
                "webapp.decorators._dashboard.snap_release_history",
                return_value={"revisions": [{"revision": 1}]},
            )
            self._release_history_patcher.start()
            self.addCleanup(self._release_history_patcher.stop)

            self.method_endpoint = method_endpoint
            self.method_api = method_api
            self.data = data
            self.json = json
            self.authorization = self._log_in(self.client)

        def _request_endpoint(self):
            """Call the endpoint under test and return the response.

            GET endpoints are fetched directly. Any other method is sent as
            a POST, with ``self.data`` as form data when it is truthy and
            otherwise with ``self.json`` as the JSON body.
            """
            if self.method_endpoint == "GET":
                return self.client.get(self.endpoint_url)
            if self.data:
                return self.client.post(self.endpoint_url, data=self.data)
            return self.client.post(self.endpoint_url, json=self.json)

        def _mock_api_response(self, **kwargs):
            """Register a mocked reply for the store API call.

            ``kwargs`` are forwarded to ``responses.Response`` (e.g.
            ``status``, ``json``, ``body``, ``headers``).
            """
            responses.add(
                responses.Response(
                    method=self.method_api, url=self.api_url, **kwargs
                )
            )

        @responses.activate
        def test_timeout(self):
            self._mock_api_response(
                body=requests.exceptions.Timeout(), status=504
            )

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)
            self.assertEqual(504, response.status_code)

        @responses.activate
        def test_connection_error(self):
            self._mock_api_response(
                body=requests.exceptions.ConnectionError(), status=500
            )

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)
            self.assertEqual(502, response.status_code)

        @responses.activate
        def test_broken_json(self):
            # Return no JSON body from the server, so the app's
            # response.json() call raises a ValueError.
            self._mock_api_response(status=500)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)
            self.assertEqual(502, response.status_code)

        @responses.activate
        def test_unknown_error(self):
            self._mock_api_response(json={}, status=500)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)
            self.assertEqual(502, response.status_code)

        @responses.activate
        def test_expired_macaroon(self):
            refresh_url = "https://login.ubuntu.com/api/v2/tokens/refresh"
            self._mock_api_response(
                json={},
                status=401,
                headers={"WWW-Authenticate": "Macaroon needs_refresh=1"},
            )
            responses.add(
                responses.POST,
                refresh_url,
                json={"discharge_macaroon": "macaroon"},
                status=200,
            )

            response = self._request_endpoint()

            # The last outgoing request must be the discharge refresh.
            called = responses.calls[-1]
            self.assertEqual(refresh_url, called.request.url)

            self.assertEqual(302, response.status_code)
            self.assertEqual(self._get_location(), response.location)

    class EndpointLoggedInErrorHandling(EndpointLoggedIn):
        """Adds store API ``error_list`` payload scenarios (4xx responses)."""

        @responses.activate
        def test_error_4xx(self):
            self._mock_api_response(json={"error_list": []}, status=400)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)
            self.assertEqual(502, response.status_code)

        @responses.activate
        def test_custom_error(self):
            payload = {
                "error_list": [
                    {"code": "error-code1"},
                    {"code": "error-code2"},
                ]
            }
            self._mock_api_response(json=payload, status=400)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)
            self.assertEqual(502, response.status_code)

        @responses.activate
        def test_account_not_signed_agreement_logged_in(self):
            payload = {
                "error_list": [
                    {
                        "code": "user-not-ready",
                        "message": "has not signed agreement",
                    }
                ]
            }
            self._mock_api_response(json=payload, status=403)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)

            self.assertEqual(302, response.status_code)
            self.assertEqual("/account/agreement", response.location)

        @responses.activate
        def test_account_no_username_logged_in(self):
            payload = {
                "error_list": [
                    {
                        "code": "user-not-ready",
                        "message": "missing store username",
                    }
                ]
            }
            self._mock_api_response(json=payload, status=403)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)

            self.assertEqual(302, response.status_code)
            self.assertEqual("/account/username", response.location)