Skip to content

Watch files over LSP #806

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ func (api *API) PositionEncoding() lsproto.PositionEncodingKind {
return lsproto.PositionEncodingKindUTF8
}

// Client implements ProjectHost.
func (api *API) Client() project.Client {
return nil
}

func (api *API) HandleRequest(id int, method string, payload []byte) ([]byte, error) {
params, err := unmarshalPayload(method, payload)
if err != nil {
Expand Down Expand Up @@ -351,7 +356,7 @@ func (api *API) getOrCreateScriptInfo(fileName string, path tspath.Path, scriptK
if !ok {
return nil
}
info = project.NewScriptInfo(fileName, path, scriptKind)
info = project.NewScriptInfo(fileName, path, scriptKind, api.host.FS())
info.SetTextFromDisk(content)
api.scriptInfosMu.Lock()
defer api.scriptInfosMu.Unlock()
Expand Down
22 changes: 21 additions & 1 deletion internal/lsp/lsproto/jsonrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ type ID struct {
int int32
}

func NewIDString(str string) *ID {
return &ID{str: str}
}

func (id *ID) MarshalJSON() ([]byte, error) {
if id.str != "" {
return json.Marshal(id.str)
Expand All @@ -43,6 +47,13 @@ func (id *ID) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &id.int)
}

func (id *ID) TryInt() (int32, bool) {
if id == nil || id.str != "" {
return 0, false
}
return id.int, true
}

func (id *ID) MustInt() int32 {
if id.str != "" {
panic("ID is not an integer")
Expand All @@ -54,11 +65,20 @@ func (id *ID) MustInt() int32 {

type RequestMessage struct {
JSONRPC JSONRPCVersion `json:"jsonrpc"`
ID *ID `json:"id"`
ID *ID `json:"id,omitempty"`
Method Method `json:"method"`
Params any `json:"params"`
}

func NewRequestMessage(method Method, id *ID, params any) *RequestMessage {
return &RequestMessage{
JSONRPC: JSONRPCVersion{},
ID: id,
Method: method,
Params: params,
}
}

func (r *RequestMessage) UnmarshalJSON(data []byte) error {
var raw struct {
JSONRPC JSONRPCVersion `json:"jsonrpc"`
Expand Down
118 changes: 112 additions & 6 deletions internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,22 @@ func NewServer(opts *ServerOptions) *Server {
newLine: opts.NewLine,
fs: opts.FS,
defaultLibraryPath: opts.DefaultLibraryPath,
watchers: make(map[project.WatcherHandle]struct{}),
}
}

var _ project.ServiceHost = (*Server)(nil)
var (
_ project.ServiceHost = (*Server)(nil)
_ project.Client = (*Server)(nil)
)

type Server struct {
r *lsproto.BaseReader
w *lsproto.BaseWriter

stderr io.Writer

clientSeq int32
requestMethod string
requestTime time.Time

Expand All @@ -62,36 +67,95 @@ type Server struct {
initializeParams *lsproto.InitializeParams
positionEncoding lsproto.PositionEncodingKind

watcheEnabled bool
watcherID int
watchers map[project.WatcherHandle]struct{}
logger *project.Logger
projectService *project.Service
converters *ls.Converters
}

// FS implements project.ProjectServiceHost.
// FS implements project.ServiceHost.
func (s *Server) FS() vfs.FS {
return s.fs
}

// DefaultLibraryPath implements project.ProjectServiceHost.
// DefaultLibraryPath implements project.ServiceHost.
func (s *Server) DefaultLibraryPath() string {
return s.defaultLibraryPath
}

// GetCurrentDirectory implements project.ProjectServiceHost.
// GetCurrentDirectory implements project.ServiceHost.
func (s *Server) GetCurrentDirectory() string {
return s.cwd
}

// NewLine implements project.ProjectServiceHost.
// NewLine implements project.ServiceHost.
func (s *Server) NewLine() string {
return s.newLine.GetNewLineCharacter()
}

// Trace implements project.ProjectServiceHost.
// Trace implements project.ServiceHost.
func (s *Server) Trace(msg string) {
s.Log(msg)
}

// Client implements project.ServiceHost.
func (s *Server) Client() project.Client {
if !s.watcheEnabled {
return nil
}
return s
}

// WatchFiles implements project.Client.
func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) {
watcherId := fmt.Sprintf("watcher-%d", s.watcherID)
if err := s.sendRequest(lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{
Registrations: []*lsproto.Registration{
{
Id: watcherId,
Method: string(lsproto.MethodWorkspaceDidChangeWatchedFiles),
RegisterOptions: ptrTo(any(lsproto.DidChangeWatchedFilesRegistrationOptions{
Watchers: watchers,
})),
},
},
}); err != nil {
return "", fmt.Errorf("failed to register file watcher: %w", err)
}

handle := project.WatcherHandle(watcherId)
s.watchers[handle] = struct{}{}
s.watcherID++
return handle, nil
}

// UnwatchFiles implements project.Client.
func (s *Server) UnwatchFiles(handle project.WatcherHandle) error {
if _, ok := s.watchers[handle]; ok {
if err := s.sendRequest(lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{
Unregisterations: []*lsproto.Unregistration{
{
Id: string(handle),
Method: string(lsproto.MethodWorkspaceDidChangeWatchedFiles),
},
},
}); err != nil {
return fmt.Errorf("failed to unregister file watcher: %w", err)
}
delete(s.watchers, handle)
return nil
}

return fmt.Errorf("no file watcher exists with ID %s", handle)
}

// PublishDiagnostics implements project.Client.
func (s *Server) PublishDiagnostics(params *lsproto.PublishDiagnosticsParams) error {
return s.sendNotification(lsproto.MethodTextDocumentPublishDiagnostics, params)
}

func (s *Server) Run() error {
for {
req, err := s.read()
Expand All @@ -105,6 +169,11 @@ func (s *Server) Run() error {
return err
}

// TODO: handle response messages
if req == nil {
continue
}

if s.initializeParams == nil {
if req.Method == lsproto.MethodInitialize {
if err := s.handleInitialize(req); err != nil {
Expand Down Expand Up @@ -132,12 +201,37 @@ func (s *Server) read() (*lsproto.RequestMessage, error) {

req := &lsproto.RequestMessage{}
if err := json.Unmarshal(data, req); err != nil {
res := &lsproto.ResponseMessage{}
if err = json.Unmarshal(data, res); err == nil {
// !!! TODO: handle response
return nil, nil
}
return nil, fmt.Errorf("%w: %w", lsproto.ErrInvalidRequest, err)
}

return req, nil
}

func (s *Server) sendRequest(method lsproto.Method, params any) error {
s.clientSeq++
id := lsproto.NewIDString(fmt.Sprintf("ts%d", s.clientSeq))
req := lsproto.NewRequestMessage(method, id, params)
data, err := json.Marshal(req)
if err != nil {
return err
}
return s.w.Write(data)
}

func (s *Server) sendNotification(method lsproto.Method, params any) error {
req := lsproto.NewRequestMessage(method, nil /*id*/, params)
data, err := json.Marshal(req)
if err != nil {
return err
}
return s.w.Write(data)
}

func (s *Server) sendResult(id *lsproto.ID, result any) error {
return s.sendResponse(&lsproto.ResponseMessage{
ID: id,
Expand Down Expand Up @@ -189,6 +283,8 @@ func (s *Server) handleMessage(req *lsproto.RequestMessage) error {
return s.handleDidSave(req)
case *lsproto.DidCloseTextDocumentParams:
return s.handleDidClose(req)
case *lsproto.DidChangeWatchedFilesParams:
return s.handleDidChangeWatchedFiles(req)
case *lsproto.DocumentDiagnosticParams:
return s.handleDocumentDiagnostic(req)
case *lsproto.HoverParams:
Expand Down Expand Up @@ -262,9 +358,14 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) error {
}

func (s *Server) handleInitialized(req *lsproto.RequestMessage) error {
if s.initializeParams.Capabilities.Workspace.DidChangeWatchedFiles != nil && *s.initializeParams.Capabilities.Workspace.DidChangeWatchedFiles.DynamicRegistration {
s.watcheEnabled = true
}

s.logger = project.NewLogger([]io.Writer{s.stderr}, "" /*file*/, project.LogLevelVerbose)
s.projectService = project.NewService(s, project.ServiceOptions{
Logger: s.logger,
WatchEnabled: s.watcheEnabled,
PositionEncoding: s.positionEncoding,
})

Expand Down Expand Up @@ -322,6 +423,11 @@ func (s *Server) handleDidClose(req *lsproto.RequestMessage) error {
return nil
}

func (s *Server) handleDidChangeWatchedFiles(req *lsproto.RequestMessage) error {
params := req.Params.(*lsproto.DidChangeWatchedFilesParams)
return s.projectService.OnWatchedFilesChanged(params.Changes)
}

func (s *Server) handleDocumentDiagnostic(req *lsproto.RequestMessage) error {
params := req.Params.(*lsproto.DocumentDiagnosticParams)
file, project := s.getFileAndProject(params.TextDocument.Uri)
Expand Down
6 changes: 3 additions & 3 deletions internal/project/documentregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ func (r *DocumentRegistry) getDocumentWorker(
if entry, ok := r.documents.Load(key); ok {
// We have an entry for this file. However, it may be for a different version of
// the script snapshot. If so, update it appropriately.
if entry.sourceFile.Version != scriptInfo.version {
if entry.sourceFile.Version != scriptInfo.Version() {
sourceFile := parser.ParseSourceFile(scriptInfo.fileName, scriptInfo.path, scriptInfo.text, scriptTarget, scanner.JSDocParsingModeParseAll)
sourceFile.Version = scriptInfo.version
sourceFile.Version = scriptInfo.Version()
entry.mu.Lock()
defer entry.mu.Unlock()
entry.sourceFile = sourceFile
Expand All @@ -104,7 +104,7 @@ func (r *DocumentRegistry) getDocumentWorker(
} else {
// Have never seen this file with these settings. Create a new source file for it.
sourceFile := parser.ParseSourceFile(scriptInfo.fileName, scriptInfo.path, scriptInfo.text, scriptTarget, scanner.JSDocParsingModeParseAll)
sourceFile.Version = scriptInfo.version
sourceFile.Version = scriptInfo.Version()
entry, _ := r.documents.LoadOrStore(key, &registryEntry{
sourceFile: sourceFile,
refCount: 0,
Expand Down
15 changes: 14 additions & 1 deletion internal/project/host.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
package project

import "github.com/microsoft/typescript-go/internal/vfs"
import (
"github.com/microsoft/typescript-go/internal/lsp/lsproto"
"github.com/microsoft/typescript-go/internal/vfs"
)

type WatcherHandle string

type Client interface {
WatchFiles(watchers []*lsproto.FileSystemWatcher) (WatcherHandle, error)
UnwatchFiles(handle WatcherHandle) error
PublishDiagnostics(params *lsproto.PublishDiagnosticsParams) error
}

type ServiceHost interface {
FS() vfs.FS
DefaultLibraryPath() string
GetCurrentDirectory() string
NewLine() string

Client() Client
}
Loading