diff --git a/internal/handler/task.go b/internal/handler/task.go index 64def88..25fbc6d 100644 --- a/internal/handler/task.go +++ b/internal/handler/task.go @@ -41,12 +41,13 @@ func New(s service.ITaskService) *TaskHandler { // @Failure 500 {object} error // @Router /api/task [post] func (th *TaskHandler) CreateTask(c *fiber.Ctx) error { + ctx := c.UserContext() ctr := new(dto.CreateTaskRequest) if err := c.BodyParser(ctr); err != nil { utils.ErrorLogger.Println("Failed to parse the body:\n", c.Body()) return err } - t, err := th.s.CreateTask(ctr.Title, ctr.Description) + t, err := th.s.CreateTask(ctx, ctr.Title, ctr.Description) if err != nil { utils.ErrorLogger.Println("Failed to create a new task:\n", err) return c.Status(fiber.StatusInternalServerError).JSON(err) @@ -63,7 +64,8 @@ func (th *TaskHandler) CreateTask(c *fiber.Ctx) error { // @Failure 500 {object} error // @Router /api/task [get] func (th *TaskHandler) GetAllTasks(c *fiber.Ctx) error { - tasks, err := th.s.GetAllTasks() + ctx := c.UserContext() + tasks, err := th.s.GetAllTasks(ctx) if err != nil { utils.ErrorLogger.Println("Failed to get all tasks:\n", err) return c.Status(fiber.StatusInternalServerError).JSON(err) @@ -81,8 +83,9 @@ func (th *TaskHandler) GetAllTasks(c *fiber.Ctx) error { // @Failure 404 {object} error // @Router /api/task/{id} [get] func (th *TaskHandler) GetTaskByID(c *fiber.Ctx) error { + ctx := c.UserContext() id := c.Params("id") - t, err := th.s.GetTaskByID(id) + t, err := th.s.GetTaskByID(ctx, id) if err != nil { utils.ErrorLogger.Printf("Failed to get task with id %s:\n%s", id, err) return c.Status(fiber.StatusNotFound).JSON(err) @@ -102,6 +105,7 @@ func (th *TaskHandler) GetTaskByID(c *fiber.Ctx) error { // @Failure 500 {object} error // @Router /api/task/{id} [patch] func (th *TaskHandler) UpdateTask(c *fiber.Ctx) error { + ctx := c.UserContext() id := c.Params("id") b := new(dto.UpdateTaskRequest) if err := c.BodyParser(b); err != nil { @@ -109,6 +113,7 @@ func (th *TaskHandler) UpdateTask(c *fiber.Ctx) error { return err } t, err := th.s.UpdateTask( + ctx, id, b.Title, b.Description, @@ -131,8 +136,9 @@ func (th *TaskHandler) UpdateTask(c *fiber.Ctx) error { // @Failure 500 {object} error // @Router /api/task/{id} [delete] func (th *TaskHandler) DeleteTask(c *fiber.Ctx) error { + ctx := c.UserContext() id := c.Params("id") - t, err := th.s.DeleteTask(id) + t, err := th.s.DeleteTask(ctx, id) if err != nil { utils.ErrorLogger.Printf("Failed to delete task with id %s:\n%s", id, err) return c.Status(fiber.StatusInternalServerError).JSON(err) diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 6bf5d5c..c7eb03f 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "devtasker/internal/utils" "strings" "time" @@ -39,8 +40,9 @@ func Authorization(c *fiber.Ctx) error { } } - c.Locals("username", claims["username"]) - c.Locals("name", claims["name"]) + ctx := context.WithValue(c.Context(), utils.UsernameKey, claims["username"]) + ctx = context.WithValue(ctx, utils.NameKey, claims["name"]) + c.SetUserContext(ctx) return c.Next() } diff --git a/internal/model/task.go b/internal/model/task.go index 187df31..ccea94e 100644 --- a/internal/model/task.go +++ b/internal/model/task.go @@ -19,4 +19,6 @@ type Task struct { Status TaskStatus `json:"status"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` + + UserUsername string } diff --git a/internal/model/user.go b/internal/model/user.go index d32f7e8..25b96f3 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -9,4 +9,6 @@ type User struct { PasswordHash string `json:"password"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` + + Tasks []Task `gorm:"foreignKey:UserUsername;references:Username"` } diff --git a/internal/repository/task.go b/internal/repository/task.go index 493bd55..c65e9e1 100644 --- a/internal/repository/task.go +++ b/internal/repository/task.go @@ -7,9 +7,9 @@ import ( ) type ITaskRepository interface { - CreateTask(title, description string) (model.Task, error) + GetAllTasks(username string) ([]model.Task, error) GetTaskByID(id string) (model.Task, error) - GetAllTasks() ([]model.Task, error) + CreateTask(username, title, description string) (model.Task, error) UpdateTask(id, title, description string, status model.TaskStatus) (model.Task, error) DeleteTask(id string) (model.Task, error) } @@ -24,14 +24,13 @@ func New(db *gorm.DB) *TaskRepository { } } -func (tr *TaskRepository) CreateTask(title, description string) (model.Task, error) { - t := model.Task{ - Title: title, - Description: description, - Status: model.Pending, +func (tr *TaskRepository) GetAllTasks(username string) ([]model.Task, error) { + var tasks []model.Task + result := tr.db.Where("user_username = ?", username).Find(&tasks) + if result.Error != nil { + return []model.Task{}, result.Error } - tr.db.Create(&t) - return t, nil + return tasks, nil } func (tr *TaskRepository) GetTaskByID(id string) (model.Task, error) { @@ -43,13 +42,15 @@ func (tr *TaskRepository) GetTaskByID(id string) (model.Task, error) { return task, nil } -func (tr *TaskRepository) GetAllTasks() ([]model.Task, error) { - var tasks []model.Task - result := tr.db.Find(&tasks) - if result.Error != nil { - return []model.Task{}, result.Error +func (tr *TaskRepository) CreateTask(username, title, description string) (model.Task, error) { + t := model.Task{ + Title: title, + Description: description, + Status: model.Pending, + UserUsername: username, } - return tasks, nil + tr.db.Create(&t) + return t, nil } func (tr *TaskRepository) UpdateTask(id, title, description string, status model.TaskStatus) (model.Task, error) { diff --git a/internal/service/task.go b/internal/service/task.go index 9bf27a7..6be8e6f 100644 --- a/internal/service/task.go +++ b/internal/service/task.go @@ -1,17 +1,19 @@ package service import ( + "context" "devtasker/internal/model" "devtasker/internal/repository" + "devtasker/internal/utils" "fmt" ) type ITaskService interface { - CreateTask(title, description string) (model.Task, error) - GetTaskByID(id string) (model.Task, error) - GetAllTasks() ([]model.Task, error) - UpdateTask(id, title, description string, status model.TaskStatus) (model.Task, error) - DeleteTask(id string) (model.Task, error) + GetAllTasks(ctx context.Context) ([]model.Task, error) + GetTaskByID(ctx context.Context, id string) (model.Task, error) + CreateTask(ctx context.Context, title, description string) (model.Task, error) + UpdateTask(ctx context.Context, id, title, description string, status model.TaskStatus) (model.Task, error) + DeleteTask(ctx context.Context, id string) (model.Task, error) } type TaskService struct { @@ -24,37 +26,54 @@ func New(r repository.ITaskRepository) *TaskService { } } -func (ts *TaskService) CreateTask(title, description string) (model.Task, error) { - if title == "" || description == "" { - return model.Task{}, fmt.Errorf("title and description cannot be empty") - } - t, err := ts.r.CreateTask(title, description) +func (ts *TaskService) GetAllTasks(ctx context.Context) ([]model.Task, error) { + username, _ := ctx.Value(utils.UsernameKey).(string) + tasks, err := ts.r.GetAllTasks(username) if err != nil { - return model.Task{}, err + return []model.Task{}, nil } - return t, nil + return tasks, nil } -func (ts *TaskService) GetTaskByID(id string) (model.Task, error) { +func (ts *TaskService) GetTaskByID(ctx context.Context, id string) (model.Task, error) { t, err := ts.r.GetTaskByID(id) if err != nil { return model.Task{}, err } + username, _ := ctx.Value(utils.UsernameKey).(string) + if t.UserUsername != username { + return model.Task{}, fmt.Errorf("you don't have permission to access this task") + } return t, nil } -func (ts *TaskService) GetAllTasks() ([]model.Task, error) { - tasks, err := ts.r.GetAllTasks() +func (ts *TaskService) CreateTask(ctx context.Context, title, description string) (model.Task, error) { + if title == "" || description == "" { + return model.Task{}, fmt.Errorf("title and description cannot be empty") + } + username, _ := ctx.Value(utils.UsernameKey).(string) + t, err := ts.r.CreateTask(username, title, description) if err != nil { - return []model.Task{}, nil + return model.Task{}, err } - return tasks, nil + return t, nil } -func (ts *TaskService) UpdateTask(id, title, description string, status model.TaskStatus) (model.Task, error) { +func (ts *TaskService) UpdateTask(ctx context.Context, id, title, description string, status model.TaskStatus) (model.Task, error) { if title == "" || description == "" || status == "" { return model.Task{}, fmt.Errorf("title and description cannot be empty") } + // Step 1: Get task by ID + task, err := ts.r.GetTaskByID(id) + if err != nil { + return model.Task{}, fmt.Errorf("task not found") + } + // Step 2: Check ownership + username, _ := ctx.Value(utils.UsernameKey).(string) + if task.UserUsername != username { + return model.Task{}, fmt.Errorf("you don't have permission to access this task") + } + // Step 3: Update the task t, err := ts.r.UpdateTask(id, title, description, status) if err != nil { return model.Task{}, err @@ -62,7 +81,18 @@ func (ts *TaskService) UpdateTask(id, title, description string, status model.Ta return t, nil } -func (ts *TaskService) DeleteTask(id string) (model.Task, error) { +func (ts *TaskService) DeleteTask(ctx context.Context, id string) (model.Task, error) { + // Step 1: Get task by ID + task, err := ts.r.GetTaskByID(id) + if err != nil { + return model.Task{}, fmt.Errorf("task not found") + } + // Step 2: Check ownership + username, _ := ctx.Value(utils.UsernameKey).(string) + if task.UserUsername != username { + return model.Task{}, fmt.Errorf("you don't have permission to access this task") + } + // Step 3: Delete the task t, err := ts.r.DeleteTask(id) if err != nil { return model.Task{}, err diff --git a/internal/utils/constant.go b/internal/utils/constant.go new file mode 100644 index 0000000..4277aa0 --- /dev/null +++ b/internal/utils/constant.go @@ -0,0 +1,8 @@ +package utils + +type ContextKey string + +const ( + UsernameKey ContextKey = "username" + NameKey ContextKey = "name" +)