Skip to content

Commit bb2543b

Browse files
committed
adds basic working authentication / authorization via Spring Security + JWT
1 parent ffe5bc4 commit bb2543b

10 files changed

Lines changed: 167 additions & 105 deletions

File tree

src/main/scala/com/example/scalaspringexperiment/SpringConfig.scala

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,29 @@ package com.example.scalaspringexperiment
22

33
import cats.effect.unsafe.IORuntime
44
import cats.effect.{IO, Resource}
5-
import com.example.scalaspringexperiment.auth.JwtAuthManager
5+
import com.example.scalaspringexperiment.auth.{JwtAuthManager, JwtServerAuthConverter}
66
import com.example.scalaspringexperiment.util.{CirceJsonDecoder, CirceJsonEncoder}
77
import doobie.{DataSourceTransactor, ExecutionContexts}
88
import doobie.util.transactor.Transactor
9-
import org.springframework.context.annotation.{Bean, Configuration}
9+
import org.springframework.context.annotation.{Bean, Configuration, Primary}
10+
import org.springframework.http.HttpStatus
1011
import org.springframework.http.codec.ServerCodecConfigurer
12+
import org.springframework.security.authentication.{DelegatingReactiveAuthenticationManager, UsernamePasswordAuthenticationToken}
1113
import org.springframework.security.config.Customizer
1214
import org.springframework.security.config.annotation.method.configuration.EnableReactiveMethodSecurity
1315
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
14-
import org.springframework.security.config.web.server.ServerHttpSecurity
16+
import org.springframework.security.config.web.server.{SecurityWebFiltersOrder, ServerHttpSecurity}
1517
import org.springframework.security.web.server.SecurityWebFilterChain
16-
import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository
18+
import org.springframework.security.web.server.authentication.AuthenticationWebFilter
19+
import org.springframework.security.web.server.context.{NoOpServerSecurityContextRepository, WebSessionServerSecurityContextRepository}
1720
import org.springframework.web.reactive.config.WebFluxConfigurer
1821

1922
import javax.sql.DataSource
2023

24+
2125
@Configuration
22-
@EnableWebFluxSecurity
23-
@EnableReactiveMethodSecurity
2426
class SpringConfig(
2527
dataSource: DataSource,
26-
jwtAuthManager: JwtAuthManager,
2728
) {
2829

2930
@Bean
@@ -32,19 +33,37 @@ class SpringConfig(
3233
ce <- ExecutionContexts.fixedThreadPool[IO](32) // our connect EC
3334
} yield Transactor.fromDataSource[IO](dataSource, ce)
3435
}
36+
}
3537

38+
@Configuration
39+
@EnableWebFluxSecurity
40+
@EnableReactiveMethodSecurity
41+
class SecurityConfig(
42+
jwtAuthManager: JwtAuthManager,
43+
) {
3644
@Bean
37-
def securityFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain = {
45+
def securityFilterChain(
46+
http: ServerHttpSecurity,
47+
jwtAuthFilter: AuthenticationWebFilter,
48+
): SecurityWebFilterChain = {
3849
http
3950
.cors(Customizer.withDefaults())
4051
.csrf(csrf => csrf.disable())
41-
.securityContextRepository(NoOpServerSecurityContextRepository.getInstance()) // optional, disables session caching
42-
.authorizeExchange(authz =>
43-
authz.anyExchange().permitAll()
44-
)
45-
.authenticationManager(jwtAuthManager)
52+
.authorizeExchange(_.anyExchange().permitAll())
53+
.addFilterAt(jwtAuthFilter, SecurityWebFiltersOrder.AUTHENTICATION)
54+
.securityContextRepository(NoOpServerSecurityContextRepository.getInstance()) // stateless auth
4655
.build()
4756
}
57+
58+
@Bean
59+
def jwtAuthFilter(
60+
jwtAuthManager: JwtAuthManager
61+
): AuthenticationWebFilter = {
62+
val filter = new AuthenticationWebFilter(jwtAuthManager)
63+
filter.setServerAuthenticationConverter(new JwtServerAuthConverter)
64+
filter.setSecurityContextRepository(NoOpServerSecurityContextRepository.getInstance())
65+
filter
66+
}
4867
}
4968

5069
@Configuration(proxyBeanMethods = false)

src/main/scala/com/example/scalaspringexperiment/auth/JwtAuthFilter.scala

Lines changed: 0 additions & 29 deletions
This file was deleted.

src/main/scala/com/example/scalaspringexperiment/auth/JwtAuthManager.scala

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ package com.example.scalaspringexperiment.auth
33
import cats.data.EitherT
44
import cats.effect.IO
55
import cats.effect.unsafe.IORuntime
6-
import com.example.scalaspringexperiment.auth.JwtAuthManager.{AuthError, InvalidCredentials, LoginSuccess, RegisterSuccess, UserExists, UserNotFound}
6+
import com.example.scalaspringexperiment.auth.JwtAuthManager.{AuthError, InvalidCredentials, LoginSuccess, ROLE_USER, RegisterSuccess, UserExists, UserNotFound}
77
import com.example.scalaspringexperiment.entity.{Person, RegisteredUser}
88
import com.example.scalaspringexperiment.service.{PersonService, RegisteredUserService}
99
import org.springframework.context.annotation.Lazy
1010
import org.springframework.security.authentication.{ReactiveAuthenticationManager, UsernamePasswordAuthenticationToken}
11+
import org.springframework.security.core.authority.SimpleGrantedAuthority
1112
import org.springframework.security.core.{Authentication, GrantedAuthority}
1213
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder
1314
import org.springframework.stereotype.Component
@@ -19,6 +20,8 @@ import java.time.Instant
1920
import scala.jdk.CollectionConverters.*
2021

2122
object JwtAuthManager {
23+
val ROLE_USER = "ROLE_USER"
24+
val ROLE_ADMIN = "ROLE_ADMIN"
2225
sealed trait AuthError {
2326
val message: String
2427
}
@@ -63,18 +66,24 @@ class JwtAuthManager(
6366
private val secretKey: String = "secretKey"
6467
private val algo = JwtAlgorithm.HS256
6568

66-
override def authenticate(
67-
authentication: Authentication
68-
): Mono[Authentication] = {
69+
override def authenticate(authentication: Authentication): Mono[Authentication] = {
70+
println(s"[AUTH DEBUG] Called with credentials: ${authentication.getCredentials}")
6971
try {
7072
val token = authentication.getCredentials.toString
7173
val claim = JwtCirce.decode(token, secretKey, Seq(algo)).get
72-
Mono.just(new UsernamePasswordAuthenticationToken(claim.subject.get, null, Seq[GrantedAuthority]().asJava))
74+
val authorities = Seq(new SimpleGrantedAuthority(ROLE_USER))
75+
val auth = new UsernamePasswordAuthenticationToken(
76+
claim.subject.get,
77+
null,
78+
authorities.asJava
79+
)
80+
Mono.just(auth)
7381
} catch {
7482
case e: Exception => Mono.empty()
7583
}
7684
}
7785

86+
7887
def generateTokenForRegisteredUser(
7988
user: RegisteredUser,
8089
expiration: Instant
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.example.scalaspringexperiment.auth
2+
3+
import org.springframework.http.HttpHeaders
4+
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
5+
import org.springframework.security.core.Authentication
6+
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter
7+
import org.springframework.web.server.ServerWebExchange
8+
import reactor.core.publisher.Mono
9+
10+
class JwtServerAuthConverter extends ServerAuthenticationConverter {
11+
override def convert(exchange: ServerWebExchange): Mono[Authentication] = {
12+
val authHeader = Option(exchange.getRequest.getHeaders.getFirst(HttpHeaders.AUTHORIZATION))
13+
authHeader match {
14+
case Some(header) if header.startsWith("Bearer ") =>
15+
val token = header.stripPrefix("Bearer ").trim
16+
Mono.just(new UsernamePasswordAuthenticationToken(token, token))
17+
case _ => Mono.empty()
18+
}
19+
}
20+
}

src/main/scala/com/example/scalaspringexperiment/controller/AuthController.scala

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,26 @@ class AuthController(
9090
@PreAuthorize("permitAll()")
9191
@GetMapping(path = Array("authtest/optional"))
9292
def maybeAuthTest(): Mono[ResponseEntity[Json]] = helper.maybeAuth { ctx =>
93-
ctx.authentication match {
94-
case Some(auth) =>
95-
IO(ResponseEntity.ok(Json.obj(
96-
"message" -> Json.fromString("Signed in"),
97-
)))
98-
case None =>
99-
IO(ResponseEntity.ok(Json.obj(
100-
"message" -> Json.fromString("Not signed in"),
101-
)))
102-
}
93+
IO(ResponseEntity.ok(Json.obj(
94+
"isAuthenticated" -> ctx.authentication.isDefined.asJson,
95+
)))
96+
}
97+
98+
@PreAuthorize("isAuthenticated()")
99+
@GetMapping(path = Array("authtest/required"))
100+
def requiredAuthTest(): Mono[ResponseEntity[Json]] = helper.auth { ctx =>
101+
IO(ResponseEntity.ok(Json.obj()))
102+
}
103+
104+
@PreAuthorize("hasRole('ROLE_USER')")
105+
@GetMapping(path = Array("authtest/user"))
106+
def userAuthTest(): Mono[ResponseEntity[Json]] = helper.auth { ctx =>
107+
IO(ResponseEntity.ok(Json.obj()))
108+
}
109+
110+
@PreAuthorize("hasRole('ROLE_ADMIN')")
111+
@GetMapping(path = Array("authtest/admin"))
112+
def adminAuthTest(): Mono[ResponseEntity[Json]] = helper.auth { ctx =>
113+
IO(ResponseEntity.ok(Json.obj()))
103114
}
104115
}

src/main/scala/com/example/scalaspringexperiment/controller/ControllerErrorHandler.scala

Lines changed: 0 additions & 28 deletions
This file was deleted.

src/main/scala/com/example/scalaspringexperiment/controller/ControllerHelper.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package com.example.scalaspringexperiment.controller
33
import cats.effect.IO
44
import cats.effect.unsafe.IORuntime
55
import org.springframework.security.core.Authentication
6-
import org.springframework.security.core.context.SecurityContextHolder
6+
import org.springframework.security.core.context.{ReactiveSecurityContextHolder, SecurityContext, SecurityContextHolder}
77
import org.springframework.stereotype.Component
88
import reactor.core.publisher.Mono
99

@@ -21,9 +21,9 @@ case class AuthCtx(
2121
class ControllerHelper(
2222
runtime: IORuntime
2323
) {
24-
24+
2525
given rt: IORuntime = runtime
26-
26+
2727
private given ioToMono[A](): Conversion[IO[A], Mono[A]] with {
2828
def apply(io: IO[A]): Mono[A] = {
2929
Mono.fromFuture(new CompletableFuture[A]().tap { cf =>
@@ -41,13 +41,16 @@ class ControllerHelper(
4141

4242
def maybeAuth[T](
4343
cb: Ctx => IO[T]
44-
): Mono[T] = cb(Ctx(
45-
authentication = Option(SecurityContextHolder.getContext.getAuthentication
46-
)))
44+
): Mono[T] = {
45+
ReactiveSecurityContextHolder.getContext
46+
.map(sc => Option(sc.getAuthentication))
47+
.defaultIfEmpty(None)
48+
.flatMap(auth => cb(Ctx(auth)))
49+
}
4750

4851
def auth[T](
49-
cb: AuthCtx => T
50-
): T = {
52+
cb: AuthCtx => IO[T]
53+
): Mono[T] = {
5154
Option(AuthCtx(
5255
authentication = SecurityContextHolder.getContext.getAuthentication
5356
)) match {

src/test/scala/com/example/scalaspringexperiment/controller/AuthControllerTest.scala

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,62 @@ class AuthControllerTest {
119119
.expectBody()
120120
.consumeWith(TestUtils.responsePrinter)
121121
}
122+
123+
@Test
124+
def maybeAuth_showsAuthenticated_forAuthenticatedRequest(): Unit = {
125+
val (_, _, token) = testUtils.newRegisteredUser()
126+
webTestClient.get()
127+
.uri("/authtest/optional")
128+
.header("Authorization", s"Bearer $token")
129+
.exchange()
130+
.expectStatus().isOk
131+
.expectBody()
132+
.consumeWith(TestUtils.responsePrinter)
133+
.jsonPath("$.isAuthenticated").isEqualTo(true)
134+
}
135+
136+
@Test
137+
def maybeAuth_showsUnauthenticated_forUnauthenticatedRequest(): Unit = {
138+
webTestClient.get()
139+
.uri("/authtest/optional")
140+
.exchange()
141+
.expectStatus().isOk
142+
.expectBody()
143+
.consumeWith(TestUtils.responsePrinter)
144+
.jsonPath("$.isAuthenticated").isEqualTo(false)
145+
}
146+
147+
@Test
148+
def requiredAuth_returnsUnauthorized_forUnauthenticatedRequest(): Unit = {
149+
webTestClient.get()
150+
.uri("/authtest/required")
151+
.exchange()
152+
.expectStatus().isUnauthorized
153+
.expectBody()
154+
.consumeWith(TestUtils.responsePrinter)
155+
}
156+
157+
@Test
158+
def userAuthTest_returnsOk_forAuthenticatedUserRequest(): Unit = {
159+
val (_, _, token) = testUtils.newRegisteredUser()
160+
webTestClient.get()
161+
.uri("/authtest/user")
162+
.header("Authorization", s"Bearer $token")
163+
.exchange()
164+
.expectStatus().isOk
165+
.expectBody()
166+
.consumeWith(TestUtils.responsePrinter)
167+
}
168+
169+
@Test
170+
def adminAuthTest_returnsForbidden_forAuthenticatedNonAdminRequest(): Unit = {
171+
val (_, _, token) = testUtils.newRegisteredUser()
172+
webTestClient.get()
173+
.uri("/authtest/admin")
174+
.header("Authorization", s"Bearer $token")
175+
.exchange()
176+
.expectStatus().isEqualTo(403)
177+
.expectBody()
178+
.consumeWith(TestUtils.responsePrinter)
179+
}
122180
}

src/test/scala/com/example/scalaspringexperiment/test/SpringTestConfig.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@ import javax.sql.DataSource
1111
@Profile(Array("test"))
1212
class SpringTestConfig(
1313
dataSource: DataSource,
14-
jwtAuthManager: JwtAuthManager,
1514
) extends SpringConfig(
1615
dataSource,
17-
jwtAuthManager
1816
) {
1917

2018
// something that exists in prod that we dont want initializing in our tests

0 commit comments

Comments
 (0)